mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into fix/openrouter-null-cache-tokens
This commit is contained in:
@@ -48,14 +48,15 @@ git diff "$BASE_BRANCH"...HEAD -- src/ | head -500
|
||||
For each changed file, determine:
|
||||
|
||||
1. **Is it a page?** (`page.tsx`) — these are the primary test targets
|
||||
2. **Is it a hook?** (`use*.ts`) — test via the page that uses it
|
||||
2. **Is it a hook?** (`use*.ts`) — test via the page/component that uses it; avoid direct `renderHook()` tests unless it is a shared reusable hook with standalone business logic
|
||||
3. **Is it a component?** (`.tsx` in `components/`) — test via the parent page unless it's complex enough to warrant isolation
|
||||
4. **Is it a helper?** (`helpers.ts`, `utils.ts`) — unit test directly if pure logic
|
||||
|
||||
**Priority order:**
|
||||
|
||||
1. Pages with new/changed data fetching or user interactions
|
||||
2. Components with complex internal logic (modals, forms, wizards)
|
||||
3. Hooks with non-trivial business logic
|
||||
3. Shared hooks with standalone business logic when UI-level coverage is impractical
|
||||
4. Pure helper functions
|
||||
|
||||
Skip: styling-only changes, type-only changes, config changes.
|
||||
@@ -163,6 +164,7 @@ describe("LibraryPage", () => {
|
||||
- Use `waitFor` when asserting side effects or state changes after interactions
|
||||
- Import `fireEvent` or `userEvent` from the test-utils for interactions
|
||||
- Do NOT mock internal hooks or functions — mock at the API boundary via MSW
|
||||
- Prefer Orval-generated MSW handlers and response builders over hand-built API response objects
|
||||
- Do NOT use `act()` manually — `render` and `fireEvent` handle it
|
||||
- Keep tests focused: one behavior per test
|
||||
- Use descriptive test names that read like sentences
|
||||
@@ -190,9 +192,7 @@ import { http, HttpResponse } from "msw";
|
||||
server.use(
|
||||
http.get("http://localhost:3000/api/proxy/api/v2/library/agents", () => {
|
||||
return HttpResponse.json({
|
||||
agents: [
|
||||
{ id: "1", name: "Test Agent", description: "A test agent" },
|
||||
],
|
||||
agents: [{ id: "1", name: "Test Agent", description: "A test agent" }],
|
||||
pagination: { total_items: 1, total_pages: 1, page: 1, page_size: 10 },
|
||||
});
|
||||
}),
|
||||
@@ -211,6 +211,7 @@ pnpm test:unit --reporter=verbose
|
||||
```
|
||||
|
||||
If tests fail:
|
||||
|
||||
1. Read the error output carefully
|
||||
2. Fix the test (not the source code, unless there is a genuine bug)
|
||||
3. Re-run until all pass
|
||||
|
||||
13
.github/workflows/platform-fullstack-ci.yml
vendored
13
.github/workflows/platform-fullstack-ci.yml
vendored
@@ -160,6 +160,7 @@ jobs:
|
||||
run: |
|
||||
cp ../backend/.env.default ../backend/.env
|
||||
echo "OPENAI_INTERNAL_API_KEY=${{ secrets.OPENAI_API_KEY }}" >> ../backend/.env
|
||||
echo "SCHEDULER_STARTUP_EMBEDDING_BACKFILL=false" >> ../backend/.env
|
||||
env:
|
||||
# Used by E2E test data script to generate embeddings for approved store agents
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
@@ -288,6 +289,14 @@ jobs:
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: autogpt_platform/frontend/pnpm-lock.yaml
|
||||
|
||||
- name: Set up tests - Cache Playwright browsers
|
||||
uses: actions/cache@v5
|
||||
with:
|
||||
path: ~/.cache/ms-playwright
|
||||
key: playwright-${{ runner.os }}-${{ hashFiles('autogpt_platform/frontend/pnpm-lock.yaml') }}
|
||||
restore-keys: |
|
||||
playwright-${{ runner.os }}-
|
||||
|
||||
- name: Copy source maps from Docker for E2E coverage
|
||||
run: |
|
||||
FRONTEND_CONTAINER=$(docker compose -f ../docker-compose.resolved.yml ps -q frontend)
|
||||
@@ -299,8 +308,8 @@ jobs:
|
||||
- name: Set up tests - Install browser 'chromium'
|
||||
run: pnpm playwright install --with-deps chromium
|
||||
|
||||
- name: Run Playwright tests
|
||||
run: pnpm test:no-build
|
||||
- name: Run Playwright E2E suite
|
||||
run: pnpm test:e2e:no-build
|
||||
continue-on-error: false
|
||||
|
||||
- name: Upload E2E coverage to Codecov
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -194,3 +194,4 @@ test.db
|
||||
.next
|
||||
# Implementation plans (generated by AI agents)
|
||||
plans/
|
||||
.claude/worktrees/
|
||||
|
||||
166
autogpt_platform/backend/agents/calculator-agent.json
Normal file
166
autogpt_platform/backend/agents/calculator-agent.json
Normal file
@@ -0,0 +1,166 @@
|
||||
{
|
||||
"id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"version": 2,
|
||||
"is_active": true,
|
||||
"name": "Calculator agent",
|
||||
"description": "",
|
||||
"instructions": null,
|
||||
"recommended_schedule_cron": null,
|
||||
"forked_from_id": null,
|
||||
"forked_from_version": null,
|
||||
"user_id": "",
|
||||
"created_at": "2026-04-13T03:45:11.241Z",
|
||||
"nodes": [
|
||||
{
|
||||
"id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"block_id": "c0a8e994-ebf1-4a9c-a4d8-89d09c86741b",
|
||||
"input_default": {
|
||||
"name": "Input",
|
||||
"secret": false,
|
||||
"advanced": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": -188.2244873046875,
|
||||
"y": 95
|
||||
}
|
||||
},
|
||||
"input_links": [],
|
||||
"output_links": [
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
},
|
||||
{
|
||||
"id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"block_id": "363ae599-353e-4804-937e-b2ee3cef3da4",
|
||||
"input_default": {
|
||||
"name": "Output",
|
||||
"secret": false,
|
||||
"advanced": false,
|
||||
"escape_html": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": 825.198974609375,
|
||||
"y": 123.75
|
||||
}
|
||||
},
|
||||
"input_links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
}
|
||||
],
|
||||
"output_links": [],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
},
|
||||
{
|
||||
"id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"block_id": "b1ab9b19-67a6-406d-abf5-2dba76d00c79",
|
||||
"input_default": {
|
||||
"b": 34,
|
||||
"operation": "Add",
|
||||
"round_result": false
|
||||
},
|
||||
"metadata": {
|
||||
"position": {
|
||||
"x": 323.0255126953125,
|
||||
"y": 121.25
|
||||
}
|
||||
},
|
||||
"input_links": [
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"output_links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
}
|
||||
],
|
||||
"graph_id": "858e2226-e047-4d19-a832-3be4a134d155",
|
||||
"graph_version": 2,
|
||||
"webhook_id": null
|
||||
}
|
||||
],
|
||||
"links": [
|
||||
{
|
||||
"id": "8cdb2f33-5b10-4cc2-8839-f8ccb70083a3",
|
||||
"source_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"sink_id": "65429c9e-a0c6-4032-a421-6899c394fa74",
|
||||
"source_name": "result",
|
||||
"sink_name": "value",
|
||||
"is_static": false
|
||||
},
|
||||
{
|
||||
"id": "432c7caa-49b9-4b70-bd21-2fa33a569601",
|
||||
"source_id": "6762da5d-6915-4836-a431-6dcd7d36a54a",
|
||||
"sink_id": "bf4a15ff-b0c4-4032-a21b-5880224af690",
|
||||
"source_name": "result",
|
||||
"sink_name": "a",
|
||||
"is_static": true
|
||||
}
|
||||
],
|
||||
"sub_graphs": [],
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Input": {
|
||||
"advanced": false,
|
||||
"secret": false,
|
||||
"title": "Input"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"Input"
|
||||
]
|
||||
},
|
||||
"output_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Output": {
|
||||
"advanced": false,
|
||||
"secret": false,
|
||||
"title": "Output"
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"Output"
|
||||
]
|
||||
},
|
||||
"has_external_trigger": false,
|
||||
"has_human_in_the_loop": false,
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"credentials_input_schema": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ from backend.data.model import (
|
||||
Credentials,
|
||||
CredentialsFieldInfo,
|
||||
CredentialsMetaInput,
|
||||
NodeExecutionStats,
|
||||
SchemaField,
|
||||
is_credentials_field_name,
|
||||
)
|
||||
@@ -43,7 +44,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import ContributorDetails, NodeExecutionStats
|
||||
from backend.data.model import ContributorDetails
|
||||
|
||||
from ..data.graph import Link
|
||||
|
||||
@@ -420,6 +421,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
_optimized_description: ClassVar[str | None] = None
|
||||
|
||||
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
|
||||
"""Return extra runtime cost to charge after this block run completes.
|
||||
|
||||
Called by the executor after a block finishes with COMPLETED status.
|
||||
The return value is the number of additional base-cost credits to
|
||||
charge beyond the single credit already collected by charge_usage
|
||||
at the start of execution. Defaults to 0 (no extra charges).
|
||||
|
||||
Override in blocks (e.g. OrchestratorBlock) that make multiple LLM
|
||||
calls within one run and should be billed per call.
|
||||
"""
|
||||
return 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
@@ -455,8 +469,6 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
disabled: If the block is disabled, it will not be available for execution.
|
||||
static_output: Whether the output links of the block are static by default.
|
||||
"""
|
||||
from backend.data.model import NodeExecutionStats
|
||||
|
||||
self.id = id
|
||||
self.input_schema = input_schema
|
||||
self.output_schema = output_schema
|
||||
@@ -474,7 +486,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
self.is_sensitive_action = is_sensitive_action
|
||||
# Read from ClassVar set by initialize_blocks()
|
||||
self.optimized_description: str | None = type(self)._optimized_description
|
||||
self.execution_stats: "NodeExecutionStats" = NodeExecutionStats()
|
||||
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
|
||||
|
||||
if self.webhook_config:
|
||||
if isinstance(self.webhook_config, BlockWebhookConfig):
|
||||
@@ -554,7 +566,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
return data
|
||||
raise ValueError(f"{self.name} did not produce any output for {output}")
|
||||
|
||||
def merge_stats(self, stats: "NodeExecutionStats") -> "NodeExecutionStats":
|
||||
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
|
||||
self.execution_stats += stats
|
||||
return self.execution_stats
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
|
||||
from backend.util.security import SENSITIVE_FIELD_NAMES
|
||||
from backend.util.tool_call_loop import (
|
||||
@@ -364,10 +365,31 @@ def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None:
|
||||
|
||||
|
||||
class OrchestratorBlock(Block):
|
||||
"""A block that uses a language model to orchestrate tool calls.
|
||||
|
||||
Supports both single-shot and iterative agent mode execution.
|
||||
|
||||
**InsufficientBalanceError propagation contract**: ``InsufficientBalanceError``
|
||||
(IBE) must always re-raise through every ``except`` block in this class.
|
||||
Swallowing IBE would let the agent loop continue with unpaid work. Every
|
||||
exception handler that catches ``Exception`` includes an explicit IBE
|
||||
re-raise carve-out for this reason.
|
||||
"""
|
||||
A block that uses a language model to orchestrate tool calls, supporting both
|
||||
single-shot and iterative agent mode execution.
|
||||
"""
|
||||
|
||||
def extra_runtime_cost(self, execution_stats: NodeExecutionStats) -> int:
|
||||
"""Charge one extra runtime cost per LLM call beyond the first.
|
||||
|
||||
In agent mode each iteration makes one LLM call. The first is already
|
||||
covered by charge_usage(); this returns the number of additional
|
||||
credits so the executor can bill the remaining calls post-completion.
|
||||
|
||||
SDK-mode exemption: when the block runs via _execute_tools_sdk_mode,
|
||||
the SDK manages its own conversation loop and only exposes aggregate
|
||||
usage. We hardcode llm_call_count=1 there (the SDK does not report a
|
||||
per-turn call count), so this method always returns 0 for SDK-mode
|
||||
executions. Per-iteration billing does not apply to SDK mode.
|
||||
"""
|
||||
return max(0, execution_stats.llm_call_count - 1)
|
||||
|
||||
# MCP server name used by the Claude Code SDK execution mode. Keep in sync
|
||||
# with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode.
|
||||
@@ -1077,7 +1099,10 @@ class OrchestratorBlock(Block):
|
||||
input_data=input_value,
|
||||
)
|
||||
|
||||
assert node_exec_result is not None, "node_exec_result should not be None"
|
||||
if node_exec_result is None:
|
||||
raise RuntimeError(
|
||||
f"upsert_execution_input returned None for node {sink_node_id}"
|
||||
)
|
||||
|
||||
# Create NodeExecutionEntry for execution manager
|
||||
node_exec_entry = NodeExecutionEntry(
|
||||
@@ -1112,15 +1137,86 @@ class OrchestratorBlock(Block):
|
||||
task=node_exec_future,
|
||||
)
|
||||
|
||||
# Execute the node directly since we're in the Orchestrator context
|
||||
node_exec_future.set_result(
|
||||
await execution_processor.on_node_execution(
|
||||
# Execute the node directly since we're in the Orchestrator context.
|
||||
# Wrap in try/except so the future is always resolved, even on
|
||||
# error — an unresolved Future would block anything awaiting it.
|
||||
#
|
||||
# on_node_execution is decorated with @async_error_logged(swallow=True),
|
||||
# which catches BaseException and returns None rather than raising.
|
||||
# Treat a None return as a failure: set_exception so the future
|
||||
# carries an error state rather than a None result, and return an
|
||||
# error response so the LLM knows the tool failed.
|
||||
try:
|
||||
tool_node_stats = await execution_processor.on_node_execution(
|
||||
node_exec=node_exec_entry,
|
||||
node_exec_progress=node_exec_progress,
|
||||
nodes_input_masks=None,
|
||||
graph_stats_pair=graph_stats_pair,
|
||||
)
|
||||
)
|
||||
if tool_node_stats is None:
|
||||
nil_err = RuntimeError(
|
||||
f"on_node_execution returned None for node {sink_node_id} "
|
||||
"(error was swallowed by @async_error_logged)"
|
||||
)
|
||||
node_exec_future.set_exception(nil_err)
|
||||
resp = _create_tool_response(
|
||||
tool_call.id,
|
||||
"Tool execution returned no result",
|
||||
responses_api=responses_api,
|
||||
)
|
||||
resp["_is_error"] = True
|
||||
return resp
|
||||
node_exec_future.set_result(tool_node_stats)
|
||||
except Exception as exec_err:
|
||||
node_exec_future.set_exception(exec_err)
|
||||
raise
|
||||
|
||||
# Charge user credits AFTER successful tool execution. Tools
|
||||
# spawned by the orchestrator bypass the main execution queue
|
||||
# (where _charge_usage is called), so we must charge here to
|
||||
# avoid free tool execution. Charging post-completion (vs.
|
||||
# pre-execution) avoids billing users for failed tool calls.
|
||||
# Skipped for dry runs.
|
||||
#
|
||||
# `error is None` intentionally excludes both Exception and
|
||||
# BaseException subclasses (e.g. CancelledError) so cancelled
|
||||
# or terminated tool runs are not billed.
|
||||
#
|
||||
# Billing errors (including non-balance exceptions) are kept
|
||||
# in a separate try/except so they are never silently swallowed
|
||||
# by the generic tool-error handler below.
|
||||
if (
|
||||
not execution_params.execution_context.dry_run
|
||||
and tool_node_stats.error is None
|
||||
):
|
||||
try:
|
||||
tool_cost, _ = await execution_processor.charge_node_usage(
|
||||
node_exec_entry,
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see OrchestratorBlock class docstring.
|
||||
# Log the billing failure here so the discarded tool result
|
||||
# is traceable before the loop aborts.
|
||||
logger.warning(
|
||||
"Insufficient balance charging for tool node %s after "
|
||||
"successful execution; agent loop will be aborted",
|
||||
sink_node_id,
|
||||
)
|
||||
raise
|
||||
except Exception:
|
||||
# Non-billing charge failures (DB outage, network, etc.)
|
||||
# must NOT propagate to the outer except handler because
|
||||
# the tool itself succeeded. Re-raising would mark the
|
||||
# tool as failed (_is_error=True), causing the LLM to
|
||||
# retry side-effectful operations. Log and continue.
|
||||
logger.exception(
|
||||
"Unexpected error charging for tool node %s; "
|
||||
"tool execution was successful",
|
||||
sink_node_id,
|
||||
)
|
||||
tool_cost = 0
|
||||
if tool_cost > 0:
|
||||
self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))
|
||||
|
||||
# Get outputs from database after execution completes using database manager client
|
||||
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
|
||||
@@ -1133,18 +1229,26 @@ class OrchestratorBlock(Block):
|
||||
if node_outputs
|
||||
else "Tool executed successfully"
|
||||
)
|
||||
return _create_tool_response(
|
||||
resp = _create_tool_response(
|
||||
tool_call.id, tool_response_content, responses_api=responses_api
|
||||
)
|
||||
resp["_is_error"] = False
|
||||
return resp
|
||||
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning("Tool execution with manager failed: %s", e)
|
||||
# Return error response
|
||||
return _create_tool_response(
|
||||
logger.warning("Tool execution with manager failed: %s", e, exc_info=True)
|
||||
# Return a generic error to the LLM — internal exception messages
|
||||
# may contain server paths, DB details, or infrastructure info.
|
||||
resp = _create_tool_response(
|
||||
tool_call.id,
|
||||
f"Tool execution failed: {e}",
|
||||
"Tool execution failed due to an internal error",
|
||||
responses_api=responses_api,
|
||||
)
|
||||
resp["_is_error"] = True
|
||||
return resp
|
||||
|
||||
async def _agent_mode_llm_caller(
|
||||
self,
|
||||
@@ -1244,13 +1348,16 @@ class OrchestratorBlock(Block):
|
||||
content = str(raw_content)
|
||||
else:
|
||||
content = "Tool executed successfully"
|
||||
tool_failed = content.startswith("Tool execution failed:")
|
||||
tool_failed = result.get("_is_error", True)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
content=content,
|
||||
is_error=tool_failed,
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Tool execution failed: %s", e)
|
||||
return ToolCallResult(
|
||||
@@ -1370,9 +1477,13 @@ class OrchestratorBlock(Block):
|
||||
"arguments": tc.arguments,
|
||||
},
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring.
|
||||
raise
|
||||
except Exception as e:
|
||||
# Catch all errors (validation, network, API) so that the block
|
||||
# surfaces them as user-visible output instead of crashing.
|
||||
# Catch all OTHER errors (validation, network, API) so that
|
||||
# the block surfaces them as user-visible output instead of
|
||||
# crashing.
|
||||
yield "error", str(e)
|
||||
return
|
||||
|
||||
@@ -1450,11 +1561,14 @@ class OrchestratorBlock(Block):
|
||||
text = content
|
||||
else:
|
||||
text = json.dumps(content)
|
||||
tool_failed = text.startswith("Tool execution failed:")
|
||||
tool_failed = result.get("_is_error", True)
|
||||
return {
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"isError": tool_failed,
|
||||
}
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("SDK tool execution failed: %s", e)
|
||||
return {
|
||||
@@ -1733,11 +1847,15 @@ class OrchestratorBlock(Block):
|
||||
await pending_task
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
pass
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring. The `finally`
|
||||
# block below still runs and records partial token usage.
|
||||
raise
|
||||
except Exception as e:
|
||||
# Surface SDK errors as user-visible output instead of crashing,
|
||||
# consistent with _execute_tools_agent_mode error handling.
|
||||
# Don't return yet — fall through to merge_stats below so
|
||||
# partial token usage is always recorded.
|
||||
# Surface OTHER SDK errors as user-visible output instead
|
||||
# of crashing, consistent with _execute_tools_agent_mode
|
||||
# error handling. Don't return yet — fall through to
|
||||
# merge_stats below so partial token usage is always recorded.
|
||||
sdk_error = e
|
||||
finally:
|
||||
# Always record usage stats, even on error. The SDK may have
|
||||
|
||||
@@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode():
|
||||
mock_execution_processor.on_node_execution = AsyncMock(
|
||||
return_value=mock_node_stats
|
||||
)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Returns (cost, remaining_balance). Must be AsyncMock because it is
|
||||
# an async method and is directly awaited in _execute_single_tool_with_manager.
|
||||
# Use a non-zero cost so the merge_stats branch is exercised.
|
||||
mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990))
|
||||
|
||||
# Mock the get_execution_outputs_by_node_exec_id method
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
|
||||
@@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode():
|
||||
# Verify tool was executed via execution processor
|
||||
assert mock_execution_processor.on_node_execution.call_count == 1
|
||||
|
||||
# Verify charge_node_usage was actually called for the successful
|
||||
# tool execution — this guards against regressions where the
|
||||
# post-execution tool charging is accidentally removed.
|
||||
assert mock_execution_processor.charge_node_usage.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_traditional_mode_default():
|
||||
|
||||
@@ -641,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
mock_execution_processor.on_node_execution.return_value = (
|
||||
mock_node_stats
|
||||
)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Must be AsyncMock because it is async and is awaited in
|
||||
# _execute_single_tool_with_manager — a plain MagicMock would
|
||||
# return a non-awaitable tuple and TypeError out, then be
|
||||
# silently swallowed by the orchestrator's catch-all.
|
||||
mock_execution_processor.charge_node_usage = AsyncMock(
|
||||
return_value=(0, 0)
|
||||
)
|
||||
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -956,6 +956,12 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
ep.execution_stats_lock = threading.Lock()
|
||||
ns = MagicMock(error=None)
|
||||
ep.on_node_execution = AsyncMock(return_value=ns)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Must be AsyncMock because it is async and is awaited in
|
||||
# _execute_single_tool_with_manager — a plain MagicMock would return a
|
||||
# non-awaitable tuple and TypeError out, then be silently swallowed by
|
||||
# the orchestrator's catch-all.
|
||||
ep.charge_node_usage = AsyncMock(return_value=(0, 0))
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=tool_sigs
|
||||
|
||||
@@ -197,6 +197,15 @@ class ChatConfig(BaseSettings):
|
||||
description="Maximum number of retries for transient API errors "
|
||||
"(429, 5xx, ECONNRESET) before surfacing the error to the user.",
|
||||
)
|
||||
claude_agent_cross_user_prompt_cache: bool = Field(
|
||||
default=True,
|
||||
description="Enable cross-user prompt caching via SystemPromptPreset. "
|
||||
"The Claude Code default prompt becomes a cacheable prefix shared "
|
||||
"across all users, and our custom prompt is appended after it. "
|
||||
"Dynamic sections (working dir, git status, auto-memory) are excluded "
|
||||
"from the prefix. Set to False to fall back to passing the system "
|
||||
"prompt as a raw string.",
|
||||
)
|
||||
claude_agent_cli_path: str | None = Field(
|
||||
default=None,
|
||||
description="Optional explicit path to a Claude Code CLI binary. "
|
||||
|
||||
@@ -0,0 +1,555 @@
|
||||
"""Tests for context fallback paths introduced in fix/copilot-transcript-resume-gate.
|
||||
|
||||
Scenario table
|
||||
==============
|
||||
|
||||
| # | use_resume | transcript_msg_count | gap | target_tokens | Expected output |
|
||||
|---|------------|----------------------|---------|---------------|--------------------------------------------|
|
||||
| A | True | covers all | empty | None | bare message (--resume has full context) |
|
||||
| B | True | stale | 2 msgs | None | gap context prepended |
|
||||
| C | True | stale | 2 msgs | 50_000 | gap compressed to budget, prepended |
|
||||
| D | False | 0 | N/A | None | full session compressed, prepended |
|
||||
| E | False | 0 | N/A | 50_000 | full session compressed to budget |
|
||||
| F | False | 2 (partial) | 2 msgs | None | full session compressed (not just gap; |
|
||||
| | | | | | CLI has zero context without --resume) |
|
||||
| G | False | 2 (partial) | 2 msgs | 50_000 | full session compressed to budget |
|
||||
| H | False | covers all | empty | None | full session compressed |
|
||||
| | | | | | (NOT bare message — the bug that was fixed)|
|
||||
| I | False | covers all | empty | 50_000 | full session compressed to tight budget |
|
||||
| J | False | 2 (partial) | n/a | None | exactly ONE compression call (full prior) |
|
||||
|
||||
Compression unit tests
|
||||
=======================
|
||||
|
||||
| # | Input | target_tokens | Expected |
|
||||
|---|----------------------|---------------|-----------------------------------------------|
|
||||
| K | [] | None | ([], False) — empty guard |
|
||||
| L | [1 msg] | None | ([msg], False) — single-msg guard |
|
||||
| M | [2+ msgs] | None | target_tokens=None forwarded to _run_compression |
|
||||
| N | [2+ msgs] | 30_000 | target_tokens=30_000 forwarded |
|
||||
| O | [2+ msgs], run fails | None | returns originals, False |
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.sdk.service import _build_query_message, _compress_messages
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(messages: list[ChatMessage]) -> ChatSession:
|
||||
now = datetime.now(UTC)
|
||||
return ChatSession(
|
||||
session_id="test-session",
|
||||
user_id="user-1",
|
||||
messages=messages,
|
||||
title="test",
|
||||
usage=[],
|
||||
started_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
|
||||
|
||||
def _msgs(*pairs: tuple[str, str]) -> list[ChatMessage]:
|
||||
return [ChatMessage(role=r, content=c) for r, c in pairs]
|
||||
|
||||
|
||||
def _passthrough_compress(target_tokens=None):
|
||||
"""Return a mock that passes messages through and records its call args."""
|
||||
calls: list[tuple[list, int | None]] = []
|
||||
|
||||
async def _mock(msgs, tok=None):
|
||||
calls.append((msgs, tok))
|
||||
return msgs, False
|
||||
|
||||
_mock.calls = calls # type: ignore[attr-defined]
|
||||
return _mock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _build_query_message — scenario A–J
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildQueryMessageResume:
|
||||
"""use_resume=True paths (--resume supplies history; only inject gap if stale)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_a_transcript_current_returns_bare_message(self):
|
||||
"""Scenario A: --resume covers full context → no prefix injected."""
|
||||
session = _make_session(
|
||||
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
|
||||
)
|
||||
result, compacted = await _build_query_message(
|
||||
"q2", session, use_resume=True, transcript_msg_count=2, session_id="s"
|
||||
)
|
||||
assert result == "q2"
|
||||
assert compacted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_b_stale_transcript_injects_gap(self, monkeypatch):
|
||||
"""Scenario B: stale transcript → gap context prepended."""
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "q1"),
|
||||
("assistant", "a1"),
|
||||
("user", "q2"),
|
||||
("assistant", "a2"),
|
||||
("user", "q3"),
|
||||
)
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
result, compacted = await _build_query_message(
|
||||
"q3", session, use_resume=True, transcript_msg_count=2, session_id="s"
|
||||
)
|
||||
assert "<conversation_history>" in result
|
||||
assert "q2" in result
|
||||
assert "a2" in result
|
||||
assert "Now, the user says:\nq3" in result
|
||||
# q1/a1 are covered by the transcript — must NOT appear in gap context
|
||||
assert "q1" not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_c_stale_transcript_passes_target_tokens(self, monkeypatch):
|
||||
"""Scenario C: target_tokens is forwarded to _compress_messages for the gap."""
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "q1"),
|
||||
("assistant", "a1"),
|
||||
("user", "q2"),
|
||||
("assistant", "a2"),
|
||||
("user", "q3"),
|
||||
)
|
||||
)
|
||||
captured: list[int | None] = []
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
captured.append(target_tokens)
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
await _build_query_message(
|
||||
"q3",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=2,
|
||||
session_id="s",
|
||||
target_tokens=50_000,
|
||||
)
|
||||
assert captured == [50_000]
|
||||
|
||||
|
||||
class TestBuildQueryMessageNoResumeNoTranscript:
|
||||
"""use_resume=False, transcript_msg_count=0 — full session compressed."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_d_full_session_compressed(self, monkeypatch):
|
||||
"""Scenario D: no resume, no transcript → compress all prior messages."""
|
||||
session = _make_session(
|
||||
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
result, compacted = await _build_query_message(
|
||||
"q2", session, use_resume=False, transcript_msg_count=0, session_id="s"
|
||||
)
|
||||
assert "<conversation_history>" in result
|
||||
assert "q1" in result
|
||||
assert "a1" in result
|
||||
assert "Now, the user says:\nq2" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_e_passes_target_tokens_to_compression(self, monkeypatch):
|
||||
"""Scenario E: target_tokens forwarded to _compress_messages."""
|
||||
session = _make_session(
|
||||
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
|
||||
)
|
||||
captured: list[int | None] = []
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
captured.append(target_tokens)
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
await _build_query_message(
|
||||
"q2",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="s",
|
||||
target_tokens=15_000,
|
||||
)
|
||||
assert captured == [15_000]
|
||||
|
||||
|
||||
class TestBuildQueryMessageNoResumeWithTranscript:
|
||||
"""use_resume=False, transcript_msg_count > 0 — gap or full-session fallback."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_f_no_resume_always_injects_full_session(self, monkeypatch):
|
||||
"""Scenario F: use_resume=False with transcript_msg_count > 0 still injects
|
||||
the FULL prior session — not just the gap since the transcript end.
|
||||
|
||||
When there is no --resume the CLI starts with zero context, so injecting
|
||||
only the post-transcript gap would silently drop all transcript-covered
|
||||
history. The correct fix is to always compress the full session.
|
||||
"""
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "q1"), # transcript_msg_count=2 covers these
|
||||
("assistant", "a1"),
|
||||
("user", "q2"), # post-transcript gap starts here
|
||||
("assistant", "a2"),
|
||||
("user", "q3"), # current message
|
||||
)
|
||||
)
|
||||
compressed_msgs: list[list] = []
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
compressed_msgs.append(list(msgs))
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
result, _ = await _build_query_message(
|
||||
"q3",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=2, # transcript covers q1/a1 but no --resume
|
||||
session_id="s",
|
||||
)
|
||||
assert "<conversation_history>" in result
|
||||
# Full session must be injected — transcript-covered turns ARE included
|
||||
assert "q1" in result
|
||||
assert "a1" in result
|
||||
assert "q2" in result
|
||||
assert "a2" in result
|
||||
assert "Now, the user says:\nq3" in result
|
||||
# Compressed exactly once with all 4 prior messages
|
||||
assert len(compressed_msgs) == 1
|
||||
assert len(compressed_msgs[0]) == 4
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_g_no_resume_passes_target_tokens(self, monkeypatch):
|
||||
"""Scenario G: target_tokens forwarded when use_resume=False + transcript_msg_count > 0."""
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "q1"),
|
||||
("assistant", "a1"),
|
||||
("user", "q2"),
|
||||
("assistant", "a2"),
|
||||
("user", "q3"),
|
||||
)
|
||||
)
|
||||
captured: list[int | None] = []
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
captured.append(target_tokens)
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
await _build_query_message(
|
||||
"q3",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=2,
|
||||
session_id="s",
|
||||
target_tokens=50_000,
|
||||
)
|
||||
assert captured == [50_000]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_h_no_resume_transcript_current_injects_full_session(
|
||||
self, monkeypatch
|
||||
):
|
||||
"""Scenario H: the bug that was fixed.
|
||||
|
||||
Old code path: use_resume=False, transcript_msg_count covers all prior
|
||||
messages → gap sub-path: gap = [] → ``return current_message, False``
|
||||
→ model received ZERO context (bare message only).
|
||||
|
||||
New code path: use_resume=False always compresses the full prior session
|
||||
regardless of transcript_msg_count — model always gets context.
|
||||
"""
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "q1"),
|
||||
("assistant", "a1"),
|
||||
("user", "q2"),
|
||||
("assistant", "a2"),
|
||||
("user", "q3"),
|
||||
)
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
result, _ = await _build_query_message(
|
||||
"q3",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=4, # covers ALL prior → old code returned bare msg
|
||||
session_id="s",
|
||||
)
|
||||
# NEW: must inject full session, NOT return bare message
|
||||
assert result != "q3"
|
||||
assert "<conversation_history>" in result
|
||||
assert "q1" in result
|
||||
assert "Now, the user says:\nq3" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_i_no_resume_target_tokens_forwarded_any_transcript_count(
|
||||
self, monkeypatch
|
||||
):
|
||||
"""Scenario I: target_tokens forwarded even when transcript_msg_count covers all."""
|
||||
session = _make_session(
|
||||
_msgs(("user", "q1"), ("assistant", "a1"), ("user", "q2"))
|
||||
)
|
||||
captured: list[int | None] = []
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
captured.append(target_tokens)
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
await _build_query_message(
|
||||
"q2",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=2,
|
||||
session_id="s",
|
||||
target_tokens=15_000,
|
||||
)
|
||||
assert 15_000 in captured
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_j_no_resume_single_compression_call(self, monkeypatch):
|
||||
"""Scenario J: use_resume=False always makes exactly ONE compression call
|
||||
(the full session), regardless of transcript coverage.
|
||||
|
||||
This verifies there is no two-step gap+fallback pattern for no-resume —
|
||||
compression is called once with the full prior session.
|
||||
"""
|
||||
session = _make_session(
|
||||
_msgs(
|
||||
("user", "q1"),
|
||||
("assistant", "a1"),
|
||||
("user", "q2"),
|
||||
("assistant", "a2"),
|
||||
("user", "q3"),
|
||||
)
|
||||
)
|
||||
call_count = 0
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages", _mock_compress
|
||||
)
|
||||
|
||||
await _build_query_message(
|
||||
"q3",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=2,
|
||||
session_id="s",
|
||||
)
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _compress_messages — unit tests K–O
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompressMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_k_empty_list_returns_empty(self):
|
||||
"""Scenario K: empty input → short-circuit, no compression."""
|
||||
result, compacted = await _compress_messages([])
|
||||
assert result == []
|
||||
assert compacted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_l_single_message_returns_as_is(self):
|
||||
"""Scenario L: single message → short-circuit (< 2 guard)."""
|
||||
msg = ChatMessage(role="user", content="hello")
|
||||
result, compacted = await _compress_messages([msg])
|
||||
assert result == [msg]
|
||||
assert compacted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_m_target_tokens_none_forwarded(self):
|
||||
"""Scenario M: target_tokens=None forwarded to _run_compression."""
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="q"),
|
||||
ChatMessage(role="assistant", content="a"),
|
||||
]
|
||||
fake_result = CompressResult(
|
||||
messages=[
|
||||
{"role": "user", "content": "q"},
|
||||
{"role": "assistant", "content": "a"},
|
||||
],
|
||||
token_count=10,
|
||||
was_compacted=False,
|
||||
original_token_count=10,
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_result,
|
||||
) as mock_run:
|
||||
await _compress_messages(msgs, target_tokens=None)
|
||||
|
||||
mock_run.assert_awaited_once()
|
||||
_, kwargs = mock_run.call_args
|
||||
assert kwargs.get("target_tokens") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_n_explicit_target_tokens_forwarded(self):
|
||||
"""Scenario N: explicit target_tokens forwarded to _run_compression."""
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="q"),
|
||||
ChatMessage(role="assistant", content="a"),
|
||||
]
|
||||
fake_result = CompressResult(
|
||||
messages=[{"role": "user", "content": "summary"}],
|
||||
token_count=5,
|
||||
was_compacted=True,
|
||||
original_token_count=50,
|
||||
)
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
return_value=fake_result,
|
||||
) as mock_run:
|
||||
result, compacted = await _compress_messages(msgs, target_tokens=30_000)
|
||||
|
||||
mock_run.assert_awaited_once()
|
||||
_, kwargs = mock_run.call_args
|
||||
assert kwargs.get("target_tokens") == 30_000
|
||||
assert compacted is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scenario_o_run_compression_exception_returns_originals(self):
|
||||
"""Scenario O: _run_compression raises → return original messages, False."""
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="q"),
|
||||
ChatMessage(role="assistant", content="a"),
|
||||
]
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._run_compression",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("compression timeout"),
|
||||
):
|
||||
result, compacted = await _compress_messages(msgs)
|
||||
|
||||
assert result == msgs
|
||||
assert compacted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compaction_messages_filtered_before_compression(self):
|
||||
"""filter_compaction_messages is applied before _run_compression is called."""
|
||||
# A compaction message is one with role=assistant and specific content pattern.
|
||||
# We verify that only real messages reach _run_compression.
|
||||
from backend.copilot.sdk.service import filter_compaction_messages
|
||||
|
||||
msgs = [
|
||||
ChatMessage(role="user", content="q"),
|
||||
ChatMessage(role="assistant", content="a"),
|
||||
]
|
||||
# filter_compaction_messages should not remove these plain messages
|
||||
filtered = filter_compaction_messages(msgs)
|
||||
assert len(filtered) == len(msgs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# target_tokens threading — _retry_target_tokens values match expectations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRetryTargetTokens:
|
||||
def test_first_retry_uses_first_slot(self):
|
||||
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
|
||||
|
||||
assert _RETRY_TARGET_TOKENS[0] == 50_000
|
||||
|
||||
def test_second_retry_uses_second_slot(self):
|
||||
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
|
||||
|
||||
assert _RETRY_TARGET_TOKENS[1] == 15_000
|
||||
|
||||
def test_second_slot_smaller_than_first(self):
|
||||
from backend.copilot.sdk.service import _RETRY_TARGET_TOKENS
|
||||
|
||||
assert _RETRY_TARGET_TOKENS[1] < _RETRY_TARGET_TOKENS[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single-message session edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSingleMessageSessions:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_resume_single_message_returns_bare(self):
|
||||
"""First turn (1 message): no prior history to inject."""
|
||||
session = _make_session([ChatMessage(role="user", content="hello")])
|
||||
result, compacted = await _build_query_message(
|
||||
"hello", session, use_resume=False, transcript_msg_count=0, session_id="s"
|
||||
)
|
||||
assert result == "hello"
|
||||
assert compacted is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_single_message_returns_bare(self):
|
||||
"""First turn with resume flag: transcript is empty so no gap."""
|
||||
session = _make_session([ChatMessage(role="user", content="hello")])
|
||||
result, compacted = await _build_query_message(
|
||||
"hello", session, use_resume=True, transcript_msg_count=0, session_id="s"
|
||||
)
|
||||
assert result == "hello"
|
||||
assert compacted is False
|
||||
@@ -6,6 +6,7 @@ import pytest
|
||||
|
||||
from backend.copilot.model import ChatMessage, ChatSession
|
||||
from backend.copilot.sdk.service import (
|
||||
_BARE_MESSAGE_TOKEN_FLOOR,
|
||||
_build_query_message,
|
||||
_format_conversation_context,
|
||||
)
|
||||
@@ -130,6 +131,34 @@ async def test_build_query_resume_up_to_date():
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_resume_misaligned_watermark():
|
||||
"""With --resume and watermark pointing at a user message, skip gap."""
|
||||
# Simulates a deleted message shifting DB positions so the watermark
|
||||
# lands on a user turn instead of the expected assistant turn.
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="turn 1"),
|
||||
ChatMessage(role="assistant", content="reply 1"),
|
||||
ChatMessage(
|
||||
role="user", content="turn 2"
|
||||
), # ← watermark points here (role=user)
|
||||
ChatMessage(role="assistant", content="reply 2"),
|
||||
ChatMessage(role="user", content="turn 3"),
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"turn 3",
|
||||
session,
|
||||
use_resume=True,
|
||||
transcript_msg_count=3, # prior[2].role == "user" — misaligned
|
||||
session_id="test-session",
|
||||
)
|
||||
# Misaligned watermark → skip gap, return bare message
|
||||
assert result == "turn 3"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_resume_stale_transcript():
|
||||
"""With --resume and stale transcript, gap context is prepended."""
|
||||
@@ -204,7 +233,7 @@ async def test_build_query_no_resume_multi_message(monkeypatch):
|
||||
)
|
||||
|
||||
# Mock _compress_messages to return the messages as-is
|
||||
async def _mock_compress(msgs):
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
@@ -237,7 +266,7 @@ async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
|
||||
]
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs):
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
return msgs, True # Simulate actual compaction
|
||||
|
||||
monkeypatch.setattr(
|
||||
@@ -253,3 +282,85 @@ async def test_build_query_no_resume_multi_message_compacted(monkeypatch):
|
||||
session_id="test-session",
|
||||
)
|
||||
assert was_compacted is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_at_token_floor():
|
||||
"""When target_tokens is at or below the floor, return bare message.
|
||||
|
||||
This is the final escape hatch: if the retry budget is exhausted and
|
||||
even the most aggressive compression might not fit, skip history
|
||||
injection entirely so the user always gets a response.
|
||||
"""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="old question"),
|
||||
ChatMessage(role="assistant", content="old answer"),
|
||||
ChatMessage(role="user", content="new question"),
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"new question",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR,
|
||||
)
|
||||
# At the floor threshold, no history is injected
|
||||
assert result == "new question"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_below_token_floor():
|
||||
"""target_tokens strictly below floor also returns bare message."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="old"),
|
||||
ChatMessage(role="assistant", content="reply"),
|
||||
ChatMessage(role="user", content="new"),
|
||||
]
|
||||
)
|
||||
result, was_compacted = await _build_query_message(
|
||||
"new",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR - 1,
|
||||
)
|
||||
assert result == "new"
|
||||
assert was_compacted is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_query_no_resume_above_token_floor_compresses(monkeypatch):
|
||||
"""target_tokens just above the floor still triggers compression."""
|
||||
session = _make_session(
|
||||
[
|
||||
ChatMessage(role="user", content="old"),
|
||||
ChatMessage(role="assistant", content="reply"),
|
||||
ChatMessage(role="user", content="new"),
|
||||
]
|
||||
)
|
||||
|
||||
async def _mock_compress(msgs, target_tokens=None):
|
||||
return msgs, False
|
||||
|
||||
monkeypatch.setattr(
|
||||
"backend.copilot.sdk.service._compress_messages",
|
||||
_mock_compress,
|
||||
)
|
||||
|
||||
result, was_compacted = await _build_query_message(
|
||||
"new",
|
||||
session,
|
||||
use_resume=False,
|
||||
transcript_msg_count=0,
|
||||
session_id="test-session",
|
||||
target_tokens=_BARE_MESSAGE_TOKEN_FLOOR + 1,
|
||||
)
|
||||
# Above the floor → history is injected (not the bare message)
|
||||
assert "<conversation_history>" in result
|
||||
assert "Now, the user says:\nnew" in result
|
||||
|
||||
@@ -7,6 +7,7 @@ tests will catch it immediately.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -90,6 +91,39 @@ def test_agent_options_accepts_required_fields():
|
||||
assert opts.cwd == "/tmp"
|
||||
|
||||
|
||||
def test_agent_options_accepts_system_prompt_preset_with_exclude_dynamic_sections():
|
||||
"""Verify ClaudeAgentOptions accepts the exact preset dict _build_system_prompt_value produces.
|
||||
|
||||
The production code always includes ``exclude_dynamic_sections=True`` in the preset
|
||||
dict. This compat test mirrors that exact shape so any SDK version that starts
|
||||
rejecting unknown keys will be caught here rather than at runtime.
|
||||
"""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
from claude_agent_sdk.types import SystemPromptPreset
|
||||
|
||||
from .service import _build_system_prompt_value
|
||||
|
||||
# Call the production helper directly so this test is tied to the real
|
||||
# dict shape rather than a hand-rolled copy.
|
||||
preset = _build_system_prompt_value("custom system prompt", cross_user_cache=True)
|
||||
assert isinstance(
|
||||
preset, dict
|
||||
), "_build_system_prompt_value must return a dict when caching is on"
|
||||
|
||||
sdk_preset = cast(SystemPromptPreset, preset)
|
||||
opts = ClaudeAgentOptions(system_prompt=sdk_preset)
|
||||
assert opts.system_prompt == sdk_preset
|
||||
|
||||
|
||||
def test_build_system_prompt_value_returns_plain_string_when_cross_user_cache_off():
|
||||
"""When cross_user_cache=False (e.g. on --resume turns), the helper must return
|
||||
a plain string so the preset+resume crash is avoided."""
|
||||
from .service import _build_system_prompt_value
|
||||
|
||||
result = _build_system_prompt_value("my prompt", cross_user_cache=False)
|
||||
assert result == "my prompt", "Must return the raw string, not a preset dict"
|
||||
|
||||
|
||||
def test_agent_options_accepts_all_our_fields():
|
||||
"""Comprehensive check of every field we use in service.py."""
|
||||
from claude_agent_sdk import ClaudeAgentOptions
|
||||
|
||||
@@ -29,6 +29,7 @@ from claude_agent_sdk import (
|
||||
ToolResultBlock,
|
||||
ToolUseBlock,
|
||||
)
|
||||
from claude_agent_sdk.types import SystemPromptPreset
|
||||
from langfuse import propagate_attributes
|
||||
from langsmith.integrations.claude_agent_sdk import configure_claude_agent_sdk
|
||||
from opentelemetry import trace as otel_trace
|
||||
@@ -260,6 +261,11 @@ class ReducedContext(NamedTuple):
|
||||
resume_file: str | None
|
||||
transcript_lost: bool
|
||||
tried_compaction: bool
|
||||
# Token budget for history compression on the DB-message fallback path.
|
||||
# None means "use model-aware default". Halved on each retry so
|
||||
# compress_context applies progressively more aggressive reduction
|
||||
# (LLM summarize → content truncate → middle-out delete → first/last trim).
|
||||
target_tokens: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -304,6 +310,10 @@ class _RetryState:
|
||||
adapter: SDKResponseAdapter
|
||||
transcript_builder: TranscriptBuilder
|
||||
usage: _TokenUsage
|
||||
# Token budget for history compression on retries (DB-message fallback path).
|
||||
# None = model-aware default. Halved each retry for progressively more
|
||||
# aggressive compression (LLM summarize → truncate → middle-out → trim).
|
||||
target_tokens: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -335,12 +345,34 @@ class _StreamContext:
|
||||
lock: AsyncClusterLock
|
||||
|
||||
|
||||
# Per-retry token budgets for the no-transcript (use_resume=False) path.
|
||||
# When there is no CLI native session to --resume, context is built from DB
|
||||
# messages via _format_conversation_context. For large sessions this text
|
||||
# can exceed the model context window; each retry halves the token budget so
|
||||
# compress_context applies progressively more aggressive reduction:
|
||||
# LLM summarize → content truncate → middle-out delete → first/last trim.
|
||||
# Index 0 = first retry, 1 = second retry; last value applies beyond that.
|
||||
_RETRY_TARGET_TOKENS: tuple[int, ...] = (50_000, 15_000)
|
||||
|
||||
# Below this token budget the model context is so tight that injecting any
|
||||
# conversation history would likely exceed the limit regardless of content.
|
||||
# _build_query_message returns the bare message when target_tokens falls to
|
||||
# or below this floor, giving the user a response instead of a hard error.
|
||||
_BARE_MESSAGE_TOKEN_FLOOR: int = 5_000
|
||||
|
||||
# Tight token budget for seeding the transcript builder on turns where no
|
||||
# CLI native session exists. Kept below _RETRY_TARGET_TOKENS[0] so the
|
||||
# seeded JSONL upload stays compact and future gap injections are small.
|
||||
_SEED_TARGET_TOKENS: int = 30_000
|
||||
|
||||
|
||||
async def _reduce_context(
|
||||
transcript_content: str,
|
||||
tried_compaction: bool,
|
||||
session_id: str,
|
||||
sdk_cwd: str,
|
||||
log_prefix: str,
|
||||
attempt: int = 1,
|
||||
) -> ReducedContext:
|
||||
"""Prepare reduced context for a retry attempt.
|
||||
|
||||
@@ -348,9 +380,19 @@ async def _reduce_context(
|
||||
On subsequent retries (or if compaction fails), drops the transcript
|
||||
entirely so the query is rebuilt from DB messages only.
|
||||
|
||||
`transcript_lost` is True when the transcript was dropped (caller
|
||||
should set `skip_transcript_upload`).
|
||||
When no transcript is available (use_resume=False fallback path), returns
|
||||
a decreasing ``target_tokens`` budget so ``compress_context`` applies
|
||||
progressively more aggressive reduction (LLM summarize → content truncate
|
||||
→ middle-out delete → first/last trim). The budget applies in
|
||||
``_build_query_message`` and is halved on each retry.
|
||||
|
||||
``transcript_lost`` is True when the transcript was dropped (caller
|
||||
should set ``skip_transcript_upload``).
|
||||
"""
|
||||
# Token budget for the DB fallback on this attempt (no-transcript path).
|
||||
idx = max(0, attempt - 1)
|
||||
retry_target = _RETRY_TARGET_TOKENS[min(idx, len(_RETRY_TARGET_TOKENS) - 1)]
|
||||
|
||||
# First retry: try compacting our transcript builder state.
|
||||
# Note: the CLI native --resume file is not updated with the compacted
|
||||
# content (it would require emitting CLI-native JSONL format), so the
|
||||
@@ -374,9 +416,14 @@ async def _reduce_context(
|
||||
return ReducedContext(tb, False, None, False, True)
|
||||
logger.warning("%s Compaction failed, dropping transcript", log_prefix)
|
||||
|
||||
# Subsequent retry or compaction failed: drop transcript entirely
|
||||
logger.warning("%s Dropping transcript, rebuilding from DB messages", log_prefix)
|
||||
return ReducedContext(TranscriptBuilder(), False, None, True, True)
|
||||
# Subsequent retry or compaction failed: drop transcript entirely.
|
||||
# Return retry_target so the caller compresses DB messages to that budget.
|
||||
logger.warning(
|
||||
"%s Dropping transcript, rebuilding from DB messages" " (target_tokens=%d)",
|
||||
log_prefix,
|
||||
retry_target,
|
||||
)
|
||||
return ReducedContext(TranscriptBuilder(), False, None, True, True, retry_target)
|
||||
|
||||
|
||||
def _append_error_marker(
|
||||
@@ -705,6 +752,34 @@ def _is_fallback_stderr(line: str) -> bool:
|
||||
return "fallback model" in line.lower()
|
||||
|
||||
|
||||
def _build_system_prompt_value(
|
||||
system_prompt: str,
|
||||
cross_user_cache: bool,
|
||||
) -> str | SystemPromptPreset:
|
||||
"""Build the ``system_prompt`` argument for :class:`ClaudeAgentOptions`.
|
||||
|
||||
When *cross_user_cache* is enabled, returns a :class:`SystemPromptPreset`
|
||||
dict so the Claude Code default prompt becomes a cacheable prefix shared
|
||||
across all users; our custom *system_prompt* is appended after it.
|
||||
|
||||
When disabled (or if the SDK is too old to support ``SystemPromptPreset``),
|
||||
the raw *system_prompt* string is returned unchanged.
|
||||
|
||||
An empty *system_prompt* is accepted: the preset dict will have
|
||||
``append: ""`` which the SDK treats as no custom suffix.
|
||||
"""
|
||||
if cross_user_cache:
|
||||
logger.debug("Using SystemPromptPreset for cross-user prompt cache")
|
||||
return SystemPromptPreset(
|
||||
type="preset",
|
||||
preset="claude_code",
|
||||
append=system_prompt,
|
||||
exclude_dynamic_sections=True,
|
||||
)
|
||||
logger.debug("Cross-user prompt cache disabled, using raw string")
|
||||
return system_prompt
|
||||
|
||||
|
||||
def _make_sdk_cwd(session_id: str) -> str:
|
||||
"""Create a safe, session-specific working directory path.
|
||||
|
||||
@@ -801,6 +876,7 @@ def _format_sdk_content_blocks(blocks: list) -> list[dict[str, Any]]:
|
||||
|
||||
async def _compress_messages(
|
||||
messages: list[ChatMessage],
|
||||
target_tokens: int | None = None,
|
||||
) -> tuple[list[ChatMessage], bool]:
|
||||
"""Compress a list of messages if they exceed the token threshold.
|
||||
|
||||
@@ -809,6 +885,10 @@ async def _compress_messages(
|
||||
`_compress_messages` and `compact_transcript` share this helper so
|
||||
client acquisition and error handling are consistent.
|
||||
|
||||
``target_tokens`` sets a hard ceiling for the compressed output so
|
||||
callers can enforce a tighter budget on retries. When ``None``,
|
||||
``compress_context`` uses the model-aware default.
|
||||
|
||||
See also:
|
||||
`_run_compression` — shared compression with timeout guards.
|
||||
`compact_transcript` — compresses JSONL transcript entries.
|
||||
@@ -832,7 +912,9 @@ async def _compress_messages(
|
||||
messages_dict.append(msg_dict)
|
||||
|
||||
try:
|
||||
result = await _run_compression(messages_dict, config.model, "[SDK]")
|
||||
result = await _run_compression(
|
||||
messages_dict, config.model, "[SDK]", target_tokens=target_tokens
|
||||
)
|
||||
except Exception as exc:
|
||||
# Guard against timeouts or unexpected errors in compression —
|
||||
# return the original messages so the caller can proceed without
|
||||
@@ -961,44 +1043,139 @@ async def _build_query_message(
|
||||
use_resume: bool,
|
||||
transcript_msg_count: int,
|
||||
session_id: str,
|
||||
target_tokens: int | None = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Build the query message with appropriate context.
|
||||
|
||||
When ``use_resume=True``, the CLI has the full session via ``--resume``;
|
||||
only a gap-fill prefix is injected when the transcript is stale.
|
||||
|
||||
When ``use_resume=False``, the CLI starts a fresh session with no prior
|
||||
context, so the full prior session is always compressed and injected via
|
||||
``_format_conversation_context``. ``compress_context`` handles size
|
||||
reduction internally (LLM summarize → content truncate → middle-out delete
|
||||
→ first/last trim). ``target_tokens`` decreases on each retry to force
|
||||
progressively more aggressive compression when the first attempt exceeds
|
||||
context limits.
|
||||
|
||||
Returns:
|
||||
Tuple of (query_message, was_compacted).
|
||||
"""
|
||||
msg_count = len(session.messages)
|
||||
prior = session.messages[:-1] # all turns except the current user message
|
||||
|
||||
logger.info(
|
||||
"[SDK] [%s] Context path: use_resume=%s, transcript_msg_count=%d,"
|
||||
" db_msg_count=%d, target_tokens=%s",
|
||||
session_id[:8],
|
||||
use_resume,
|
||||
transcript_msg_count,
|
||||
msg_count,
|
||||
target_tokens,
|
||||
)
|
||||
|
||||
if use_resume and transcript_msg_count > 0:
|
||||
if transcript_msg_count < msg_count - 1:
|
||||
gap = session.messages[transcript_msg_count:-1]
|
||||
compressed, was_compressed = await _compress_messages(gap)
|
||||
# Sanity-check the watermark: the last covered position should be
|
||||
# an assistant turn. A user-role message here means the count is
|
||||
# misaligned (e.g. a message was deleted and DB positions shifted).
|
||||
# Skip the gap rather than injecting wrong context — the CLI session
|
||||
# loaded via --resume still has good history.
|
||||
if prior[transcript_msg_count - 1].role != "assistant":
|
||||
logger.warning(
|
||||
"[SDK] [%s] Watermark misaligned: prior[%d].role=%r"
|
||||
" (expected 'assistant') — skipping gap to avoid"
|
||||
" injecting wrong context (transcript=%d, db=%d)",
|
||||
session_id[:8],
|
||||
transcript_msg_count - 1,
|
||||
prior[transcript_msg_count - 1].role,
|
||||
transcript_msg_count,
|
||||
msg_count,
|
||||
)
|
||||
return current_message, False
|
||||
gap = prior[transcript_msg_count:]
|
||||
compressed, was_compressed = await _compress_messages(gap, target_tokens)
|
||||
gap_context = _format_conversation_context(compressed)
|
||||
if gap_context:
|
||||
logger.info(
|
||||
"[SDK] Transcript stale: covers %d of %d messages, "
|
||||
"gap=%d (compressed=%s)",
|
||||
"gap=%d (compressed=%s), gap_context_bytes=%d",
|
||||
transcript_msg_count,
|
||||
msg_count,
|
||||
len(gap),
|
||||
was_compressed,
|
||||
len(gap_context),
|
||||
)
|
||||
return (
|
||||
f"{gap_context}\n\nNow, the user says:\n{current_message}",
|
||||
was_compressed,
|
||||
)
|
||||
logger.warning(
|
||||
"[SDK] [%s] Transcript stale: gap produced empty context"
|
||||
" (%d msgs, transcript=%d/%d) — sending message without gap prefix",
|
||||
session_id[:8],
|
||||
len(gap),
|
||||
transcript_msg_count,
|
||||
msg_count,
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
"[SDK] [%s] --resume covers full context (%d messages)",
|
||||
session_id[:8],
|
||||
transcript_msg_count,
|
||||
)
|
||||
return current_message, False
|
||||
|
||||
elif not use_resume and msg_count > 1:
|
||||
# No --resume: the CLI starts a fresh session with no prior context.
|
||||
# Injecting only the post-transcript gap would omit the transcript-covered
|
||||
# prefix entirely, so always compress the full prior session here.
|
||||
# compress_context handles size reduction internally (LLM summarize →
|
||||
# content truncate → middle-out delete → first/last trim).
|
||||
|
||||
# Final escape hatch: if the token budget is at or below the floor,
|
||||
# the model context is so tight that even fully compressed history
|
||||
# would risk a "prompt too long" error. Return the bare message so
|
||||
# the user always gets a response rather than a hard failure.
|
||||
if target_tokens is not None and target_tokens <= _BARE_MESSAGE_TOKEN_FLOOR:
|
||||
logger.warning(
|
||||
"[SDK] [%s] target_tokens=%d at or below floor (%d) —"
|
||||
" skipping history injection to guarantee response delivery"
|
||||
" (session has %d messages)",
|
||||
session_id[:8],
|
||||
target_tokens,
|
||||
_BARE_MESSAGE_TOKEN_FLOOR,
|
||||
msg_count,
|
||||
)
|
||||
return current_message, False
|
||||
|
||||
logger.warning(
|
||||
f"[SDK] Using compression fallback for session "
|
||||
f"{session_id} ({msg_count} messages) — no transcript for --resume"
|
||||
"[SDK] [%s] No --resume for %d-message session — compressing"
|
||||
" full session history (pod affinity issue or first turn after"
|
||||
" restore failure); target_tokens=%s",
|
||||
session_id[:8],
|
||||
msg_count,
|
||||
target_tokens,
|
||||
)
|
||||
compressed, was_compressed = await _compress_messages(session.messages[:-1])
|
||||
compressed, was_compressed = await _compress_messages(prior, target_tokens)
|
||||
history_context = _format_conversation_context(compressed)
|
||||
if history_context:
|
||||
logger.info(
|
||||
"[SDK] [%s] Fallback context built: compressed=%s," " context_bytes=%d",
|
||||
session_id[:8],
|
||||
was_compressed,
|
||||
len(history_context),
|
||||
)
|
||||
return (
|
||||
f"{history_context}\n\nNow, the user says:\n{current_message}",
|
||||
was_compressed,
|
||||
)
|
||||
logger.warning(
|
||||
"[SDK] [%s] Fallback context empty after compression"
|
||||
" (%d messages) — sending message without history",
|
||||
session_id[:8],
|
||||
len(prior),
|
||||
)
|
||||
|
||||
return current_message, False
|
||||
|
||||
@@ -1927,6 +2104,48 @@ async def _run_stream_attempt(
|
||||
)
|
||||
|
||||
|
||||
async def _seed_transcript(
|
||||
session: ChatSession,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
transcript_covers_prefix: bool,
|
||||
transcript_msg_count: int,
|
||||
log_prefix: str,
|
||||
) -> tuple[str, bool, int]:
|
||||
"""Seed the transcript builder from compressed DB messages.
|
||||
|
||||
Called when ``use_resume=False`` and no prior transcript exists in storage
|
||||
so that ``upload_transcript`` saves a compact version for future turns.
|
||||
This ensures the next turn can use the full-session compression path with
|
||||
the benefit of an already-compressed baseline, and a restored CLI session
|
||||
on the next pod gets a usable compact base even for sessions that started
|
||||
on old pods.
|
||||
|
||||
Returns ``(transcript_content, transcript_covers_prefix, transcript_msg_count)``
|
||||
updated values — unchanged if seeding is not possible.
|
||||
"""
|
||||
if len(session.messages) <= 1:
|
||||
return "", transcript_covers_prefix, transcript_msg_count
|
||||
|
||||
_prior = session.messages[:-1]
|
||||
_comp, _ = await _compress_messages(_prior, _SEED_TARGET_TOKENS)
|
||||
if not _comp:
|
||||
return "", transcript_covers_prefix, transcript_msg_count
|
||||
|
||||
_seeded = _session_messages_to_transcript(_comp)
|
||||
if not _seeded or not validate_transcript(_seeded):
|
||||
return "", transcript_covers_prefix, transcript_msg_count
|
||||
|
||||
transcript_builder.load_previous(_seeded, log_prefix=log_prefix)
|
||||
logger.info(
|
||||
"%s Seeded transcript from %d compressed DB messages"
|
||||
" for next-turn upload (seed_target_tokens=%d)",
|
||||
log_prefix,
|
||||
len(_comp),
|
||||
_SEED_TARGET_TOKENS,
|
||||
)
|
||||
return _seeded, True, len(_prior)
|
||||
|
||||
|
||||
async def stream_chat_completion_sdk(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -2198,9 +2417,20 @@ async def stream_chat_completion_sdk(
|
||||
# Builder loaded but CLI native session not available.
|
||||
# --resume will not be used this turn; upload after turn
|
||||
# will seed the native session for the next turn.
|
||||
#
|
||||
# Still record transcript_msg_count so _build_query_message
|
||||
# can use the transcript-aware gap path (inject only new
|
||||
# messages since the transcript end) instead of compressing
|
||||
# the full DB history. This avoids prompt-too-long on
|
||||
# large sessions where the CLI session is temporarily
|
||||
# unavailable (e.g. mixed-version rolling deployment).
|
||||
transcript_msg_count = dl.message_count
|
||||
logger.info(
|
||||
"%s CLI session not restored — running without --resume this turn",
|
||||
"%s CLI session not restored — running without"
|
||||
" --resume this turn (transcript_msg_count=%d for"
|
||||
" gap-aware fallback)",
|
||||
log_prefix,
|
||||
transcript_msg_count,
|
||||
)
|
||||
else:
|
||||
logger.warning("%s Transcript downloaded but invalid", log_prefix)
|
||||
@@ -2295,8 +2525,19 @@ async def stream_chat_completion_sdk(
|
||||
sid,
|
||||
)
|
||||
|
||||
# Use SystemPromptPreset for cross-user prompt caching.
|
||||
# WORKAROUND: CLI 2.1.97 (sdk 0.1.58) exits code 1 when
|
||||
# excludeDynamicSections=True is in the initialize request AND
|
||||
# --resume is active. Disable the preset on resumed turns.
|
||||
# Turn 1 still gets the preset (no --resume).
|
||||
_cross_user = config.claude_agent_cross_user_prompt_cache and not use_resume
|
||||
system_prompt_value = _build_system_prompt_value(
|
||||
system_prompt,
|
||||
cross_user_cache=_cross_user,
|
||||
)
|
||||
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
"system_prompt": system_prompt_value,
|
||||
"mcp_servers": {"copilot": mcp_server},
|
||||
"allowed_tools": allowed,
|
||||
"disallowed_tools": disallowed,
|
||||
@@ -2425,6 +2666,22 @@ async def stream_chat_completion_sdk(
|
||||
if attachments.hint:
|
||||
query_message = f"{query_message}\n\n{attachments.hint}"
|
||||
|
||||
# When running without --resume and no prior transcript in storage,
|
||||
# seed the transcript builder from compressed DB messages so that
|
||||
# upload_transcript saves a compact version for future turns.
|
||||
if not use_resume and not transcript_content and not skip_transcript_upload:
|
||||
(
|
||||
transcript_content,
|
||||
transcript_covers_prefix,
|
||||
transcript_msg_count,
|
||||
) = await _seed_transcript(
|
||||
session,
|
||||
transcript_builder,
|
||||
transcript_covers_prefix,
|
||||
transcript_msg_count,
|
||||
log_prefix,
|
||||
)
|
||||
|
||||
tried_compaction = False
|
||||
|
||||
# Build the per-request context carrier (shared across attempts).
|
||||
@@ -2507,12 +2764,14 @@ async def stream_chat_completion_sdk(
|
||||
session_id,
|
||||
sdk_cwd,
|
||||
log_prefix,
|
||||
attempt=attempt,
|
||||
)
|
||||
state.transcript_builder = ctx.builder
|
||||
state.use_resume = ctx.use_resume
|
||||
state.resume_file = ctx.resume_file
|
||||
tried_compaction = ctx.tried_compaction
|
||||
state.transcript_msg_count = 0
|
||||
state.target_tokens = ctx.target_tokens
|
||||
if ctx.transcript_lost:
|
||||
skip_transcript_upload = True
|
||||
|
||||
@@ -2530,9 +2789,18 @@ async def stream_chat_completion_sdk(
|
||||
# T2+ retry without --resume: do not pass --session-id.
|
||||
# The T1 session file already exists at that path; re-using
|
||||
# the same ID would fail with "Session ID already in use".
|
||||
# The upload guard skips T2+ no-resume turns anyway.
|
||||
sdk_options_kwargs_retry.pop("resume", None)
|
||||
sdk_options_kwargs_retry.pop("session_id", None)
|
||||
# Recompute system_prompt for retry — ctx.use_resume may have
|
||||
# changed (context reduction enabled --resume). CLI 2.1.97
|
||||
# crashes when excludeDynamicSections=True is combined with
|
||||
# --resume, so disable the cross-user preset on resumed turns.
|
||||
_cross_user_retry = (
|
||||
config.claude_agent_cross_user_prompt_cache and not ctx.use_resume
|
||||
)
|
||||
sdk_options_kwargs_retry["system_prompt"] = _build_system_prompt_value(
|
||||
system_prompt, cross_user_cache=_cross_user_retry
|
||||
)
|
||||
state.options = ClaudeAgentOptions(**sdk_options_kwargs_retry) # type: ignore[arg-type] # dynamic kwargs
|
||||
state.query_message, state.was_compacted = await _build_query_message(
|
||||
current_message,
|
||||
@@ -2540,6 +2808,7 @@ async def stream_chat_completion_sdk(
|
||||
state.use_resume,
|
||||
state.transcript_msg_count,
|
||||
session_id,
|
||||
target_tokens=state.target_tokens,
|
||||
)
|
||||
if attachments.hint:
|
||||
state.query_message = f"{state.query_message}\n\n{attachments.hint}"
|
||||
@@ -3025,6 +3294,21 @@ async def stream_chat_completion_sdk(
|
||||
# the shielded inner coroutine continues running to completion so the
|
||||
# upload is not lost. This is intentional and matches the pattern
|
||||
# used for upload_transcript immediately above.
|
||||
#
|
||||
# NOTE: upload is attempted regardless of state.use_resume — even when
|
||||
# this turn ran without --resume (restore failed or first T2+ on a new
|
||||
# pod), the T1 session file at the expected path may still be present
|
||||
# and should be re-uploaded so the next turn can resume from it.
|
||||
# upload_cli_session silently skips when the file is absent, so this is
|
||||
# always safe.
|
||||
#
|
||||
# Intentionally NOT gated on skip_transcript_upload: that flag is set
|
||||
# when our custom JSONL transcript is dropped (transcript_lost=True on
|
||||
# reduced-context retries) but the CLI's native session file is written
|
||||
# independently. Blocking CLI upload on transcript_lost would prevent
|
||||
# T1 prompt-too-long retries from uploading their valid session file,
|
||||
# breaking --resume on the next pod. The ended_with_stream_error gate
|
||||
# above already covers actual turn failures.
|
||||
if (
|
||||
config.claude_agent_use_resume
|
||||
and user_id
|
||||
@@ -3032,9 +3316,15 @@ async def stream_chat_completion_sdk(
|
||||
and session is not None
|
||||
and state is not None
|
||||
and not ended_with_stream_error
|
||||
and not skip_transcript_upload
|
||||
and (not has_history or state.use_resume)
|
||||
):
|
||||
logger.info(
|
||||
"%s Attempting CLI session upload"
|
||||
" (use_resume=%s, has_history=%s, skip_transcript=%s)",
|
||||
log_prefix,
|
||||
state.use_resume,
|
||||
has_history,
|
||||
skip_transcript_upload,
|
||||
)
|
||||
try:
|
||||
await asyncio.shield(
|
||||
upload_cli_session(
|
||||
|
||||
@@ -15,6 +15,7 @@ from claude_agent_sdk import AssistantMessage, TextBlock, ToolUseBlock
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import (
|
||||
_RETRY_TARGET_TOKENS,
|
||||
ReducedContext,
|
||||
_is_prompt_too_long,
|
||||
_is_tool_only_message,
|
||||
@@ -208,6 +209,24 @@ class TestReduceContext:
|
||||
|
||||
assert ctx.transcript_lost is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drop_returns_target_tokens_attempt_1(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=1)
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drop_returns_target_tokens_attempt_2(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=2)
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[1]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drop_clamps_attempt_beyond_limits(self) -> None:
|
||||
ctx = await _reduce_context("", False, "sess-1", "/tmp", "[t]", attempt=99)
|
||||
assert ctx.transcript_lost is True
|
||||
assert ctx.target_tokens == _RETRY_TARGET_TOKENS[-1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _iter_sdk_messages
|
||||
|
||||
@@ -8,7 +8,10 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
from .service import (
|
||||
_build_system_prompt_value,
|
||||
_is_sdk_disconnect_error,
|
||||
_normalize_model_name,
|
||||
_prepare_file_attachments,
|
||||
@@ -397,6 +400,7 @@ _CONFIG_ENV_VARS = (
|
||||
"OPENAI_BASE_URL",
|
||||
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
|
||||
"CHAT_USE_CLAUDE_AGENT_SDK",
|
||||
"CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE",
|
||||
)
|
||||
|
||||
|
||||
@@ -656,3 +660,62 @@ class TestSafeCloseSdkClient:
|
||||
client.__aexit__ = AsyncMock(side_effect=ValueError("invalid argument"))
|
||||
with pytest.raises(ValueError, match="invalid argument"):
|
||||
await _safe_close_sdk_client(client, "[test]")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SystemPromptPreset — cross-user prompt caching
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSystemPromptPreset:
|
||||
"""Tests for _build_system_prompt_value — cross-user prompt caching."""
|
||||
|
||||
def test_preset_dict_structure_when_enabled(self):
|
||||
"""When cross_user_cache is True, returns a _SystemPromptPreset dict."""
|
||||
custom_prompt = "You are a helpful assistant."
|
||||
result = _build_system_prompt_value(custom_prompt, cross_user_cache=True)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["type"] == "preset"
|
||||
assert result["preset"] == "claude_code"
|
||||
assert result["append"] == custom_prompt
|
||||
assert result["exclude_dynamic_sections"] is True
|
||||
|
||||
def test_raw_string_when_disabled(self):
|
||||
"""When cross_user_cache is False, returns the raw string."""
|
||||
custom_prompt = "You are a helpful assistant."
|
||||
result = _build_system_prompt_value(custom_prompt, cross_user_cache=False)
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == custom_prompt
|
||||
|
||||
def test_empty_string_with_cache_enabled(self):
|
||||
"""Empty system_prompt with cross_user_cache=True produces append=''."""
|
||||
result = _build_system_prompt_value("", cross_user_cache=True)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["type"] == "preset"
|
||||
assert result["preset"] == "claude_code"
|
||||
assert result["append"] == ""
|
||||
assert result["exclude_dynamic_sections"] is True
|
||||
|
||||
def test_default_config_is_enabled(self, _clean_config_env):
|
||||
"""The default value for claude_agent_cross_user_prompt_cache is True."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
assert cfg.claude_agent_cross_user_prompt_cache is True
|
||||
|
||||
def test_env_var_disables_cache(self, _clean_config_env, monkeypatch):
|
||||
"""CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE=false disables caching."""
|
||||
monkeypatch.setenv("CHAT_CLAUDE_AGENT_CROSS_USER_PROMPT_CACHE", "false")
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
assert cfg.claude_agent_cross_user_prompt_cache is False
|
||||
|
||||
@@ -960,7 +960,7 @@ class TestRunCompression:
|
||||
)
|
||||
call_count = [0]
|
||||
|
||||
async def _compress_side_effect(*, messages, model, client):
|
||||
async def _compress_side_effect(*, messages, model, client, target_tokens=None):
|
||||
call_count[0] += 1
|
||||
if client is not None:
|
||||
# Simulate a hang that exceeds the timeout
|
||||
|
||||
@@ -1179,6 +1179,7 @@ async def _run_compression(
|
||||
messages: list[dict],
|
||||
model: str,
|
||||
log_prefix: str,
|
||||
target_tokens: int | None = None,
|
||||
) -> CompressResult:
|
||||
"""Run LLM-based compression with truncation fallback.
|
||||
|
||||
@@ -1187,6 +1188,12 @@ async def _run_compression(
|
||||
truncation-based compression which drops older messages without
|
||||
summarization.
|
||||
|
||||
``target_tokens`` sets a hard token ceiling for the compressed output.
|
||||
When ``None``, ``compress_context`` derives the limit from the model's
|
||||
context window. Pass a smaller value on retries to force more aggressive
|
||||
compression — the compressor will LLM-summarize, content-truncate,
|
||||
middle-out delete, and first/last trim until the result fits.
|
||||
|
||||
A 60-second timeout prevents a hung LLM call from blocking the
|
||||
retry path indefinitely. The truncation fallback also has a
|
||||
30-second timeout to guard against slow tokenization on very large
|
||||
@@ -1196,18 +1203,27 @@ async def _run_compression(
|
||||
if client is None:
|
||||
logger.warning("%s No OpenAI client configured, using truncation", log_prefix)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
compress_context(
|
||||
messages=messages, model=model, client=None, target_tokens=target_tokens
|
||||
),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=client),
|
||||
compress_context(
|
||||
messages=messages,
|
||||
model=model,
|
||||
client=client,
|
||||
target_tokens=target_tokens,
|
||||
),
|
||||
timeout=_COMPACTION_TIMEOUT_SECONDS,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await asyncio.wait_for(
|
||||
compress_context(messages=messages, model=model, client=None),
|
||||
compress_context(
|
||||
messages=messages, model=model, client=None, target_tokens=target_tokens
|
||||
),
|
||||
timeout=_TRUNCATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
@@ -349,7 +349,7 @@ class UserCreditBase(ABC):
|
||||
CreditTransactionType.GRANT,
|
||||
CreditTransactionType.TOP_UP,
|
||||
]:
|
||||
from backend.executor.manager import (
|
||||
from backend.executor.billing import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
@@ -554,7 +554,7 @@ class UserCreditBase(ABC):
|
||||
in [CreditTransactionType.GRANT, CreditTransactionType.TOP_UP]
|
||||
):
|
||||
# Lazy import to avoid circular dependency with executor.manager
|
||||
from backend.executor.manager import (
|
||||
from backend.executor.billing import (
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
|
||||
|
||||
@@ -852,6 +852,7 @@ class NodeExecutionStats(BaseModel):
|
||||
output_token_count: int = 0
|
||||
cache_read_token_count: int = 0
|
||||
cache_creation_token_count: int = 0
|
||||
cost: int = 0
|
||||
extra_cost: int = 0
|
||||
extra_steps: int = 0
|
||||
provider_cost: float | None = None
|
||||
|
||||
@@ -8,6 +8,7 @@ from prisma.models import User as PrismaUser
|
||||
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.util.cache import cached
|
||||
from backend.util.json import SafeJson
|
||||
|
||||
@@ -142,6 +143,7 @@ class UserCostSummary(BaseModel):
|
||||
total_cache_read_tokens: int = 0
|
||||
total_cache_creation_tokens: int = 0
|
||||
request_count: int
|
||||
cost_bearing_request_count: int = 0
|
||||
|
||||
|
||||
class CostLogRow(BaseModel):
|
||||
@@ -163,12 +165,27 @@ class CostLogRow(BaseModel):
|
||||
cache_creation_tokens: int | None = None
|
||||
|
||||
|
||||
class CostBucket(BaseModel):
|
||||
bucket: str
|
||||
count: int
|
||||
|
||||
|
||||
class PlatformCostDashboard(BaseModel):
|
||||
by_provider: list[ProviderCostSummary]
|
||||
by_user: list[UserCostSummary]
|
||||
total_cost_microdollars: int
|
||||
total_requests: int
|
||||
total_users: int
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
avg_input_tokens_per_request: float = 0.0
|
||||
avg_output_tokens_per_request: float = 0.0
|
||||
avg_cost_microdollars_per_request: float = 0.0
|
||||
cost_p50_microdollars: float = 0.0
|
||||
cost_p75_microdollars: float = 0.0
|
||||
cost_p95_microdollars: float = 0.0
|
||||
cost_p99_microdollars: float = 0.0
|
||||
cost_buckets: list[CostBucket] = []
|
||||
|
||||
|
||||
def _si(row: dict, field: str) -> int:
|
||||
@@ -228,6 +245,66 @@ def _build_prisma_where(
|
||||
return where
|
||||
|
||||
|
||||
def _build_raw_where(
|
||||
start: datetime | None,
|
||||
end: datetime | None,
|
||||
provider: str | None,
|
||||
user_id: str | None,
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
) -> tuple[str, list]:
|
||||
"""Build a parameterised WHERE clause for raw SQL queries.
|
||||
|
||||
Mirrors the filter logic of ``_build_prisma_where`` so there is a single
|
||||
source of truth for which columns are filtered and how. The first clause
|
||||
always restricts to ``cost_usd`` tracking type unless *tracking_type* is
|
||||
explicitly provided by the caller.
|
||||
"""
|
||||
params: list = []
|
||||
clauses: list[str] = []
|
||||
idx = 1
|
||||
|
||||
# Always filter by tracking type — defaults to cost_usd for percentile /
|
||||
# bucket queries that only make sense on cost-denominated rows.
|
||||
tt = tracking_type if tracking_type is not None else "cost_usd"
|
||||
clauses.append(f'"trackingType" = ${idx}')
|
||||
params.append(tt)
|
||||
idx += 1
|
||||
|
||||
if start is not None:
|
||||
clauses.append(f'"createdAt" >= ${idx}::timestamptz')
|
||||
params.append(start)
|
||||
idx += 1
|
||||
|
||||
if end is not None:
|
||||
clauses.append(f'"createdAt" <= ${idx}::timestamptz')
|
||||
params.append(end)
|
||||
idx += 1
|
||||
|
||||
if provider is not None:
|
||||
clauses.append(f'"provider" = ${idx}')
|
||||
params.append(provider.lower())
|
||||
idx += 1
|
||||
|
||||
if user_id is not None:
|
||||
clauses.append(f'"userId" = ${idx}')
|
||||
params.append(user_id)
|
||||
idx += 1
|
||||
|
||||
if model is not None:
|
||||
clauses.append(f'"model" = ${idx}')
|
||||
params.append(model)
|
||||
idx += 1
|
||||
|
||||
if block_name is not None:
|
||||
clauses.append(f'LOWER("blockName") = LOWER(${idx})')
|
||||
params.append(block_name)
|
||||
idx += 1
|
||||
|
||||
return (" AND ".join(clauses), params)
|
||||
|
||||
|
||||
@cached(ttl_seconds=30)
|
||||
async def get_platform_cost_dashboard(
|
||||
start: datetime | None = None,
|
||||
@@ -256,6 +333,14 @@ async def get_platform_cost_dashboard(
|
||||
start, end, provider, user_id, model, block_name, tracking_type
|
||||
)
|
||||
|
||||
# For per-user tracking-type breakdown we intentionally omit the
|
||||
# tracking_type filter so cost_usd and tokens rows are always present.
|
||||
# This ensures cost_bearing_request_count is correct even when the caller
|
||||
# is filtering the main view by a different tracking_type.
|
||||
where_no_tracking_type = _build_prisma_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type=None
|
||||
)
|
||||
|
||||
sum_fields = {
|
||||
"costMicrodollars": True,
|
||||
"inputTokens": True,
|
||||
@@ -266,13 +351,18 @@ async def get_platform_cost_dashboard(
|
||||
"trackingAmount": True,
|
||||
}
|
||||
|
||||
# Run all four aggregation queries in parallel.
|
||||
(
|
||||
by_provider_groups,
|
||||
by_user_groups,
|
||||
total_user_groups,
|
||||
total_agg_groups,
|
||||
) = await asyncio.gather(
|
||||
# Build parameterised WHERE clause for the raw SQL percentile/bucket
|
||||
# queries. Uses _build_raw_where so filter logic is shared with
|
||||
# _build_prisma_where and only maintained in one place.
|
||||
# Always force tracking_type=None here so _build_raw_where defaults to
|
||||
# "cost_usd" — percentile and histogram queries only make sense on
|
||||
# cost-denominated rows, regardless of what the caller is filtering.
|
||||
raw_where, raw_params = _build_raw_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type=None
|
||||
)
|
||||
|
||||
# Queries that always run regardless of tracking_type filter.
|
||||
common_queries = [
|
||||
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
|
||||
# sort by total cost descending in Python after fetch.
|
||||
PrismaLog.prisma().group_by(
|
||||
@@ -288,20 +378,125 @@ async def get_platform_cost_dashboard(
|
||||
sum=sum_fields,
|
||||
count=True,
|
||||
),
|
||||
# Per-user cost-bearing request count: group by (userId, trackingType)
|
||||
# so we can compute the correct denominator for per-user avg cost.
|
||||
# Uses where_no_tracking_type so cost_usd rows are always included
|
||||
# even when the caller filters the main view by a different tracking_type.
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["userId", "trackingType"],
|
||||
where=where_no_tracking_type,
|
||||
count=True,
|
||||
),
|
||||
# Distinct user count: group by userId, count groups.
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["userId"],
|
||||
where=where,
|
||||
count=True,
|
||||
),
|
||||
# Total aggregate: group by provider (no limit) to sum across all
|
||||
# matching rows. Summed in Python to get grand totals.
|
||||
# Total aggregate (filtered): group by (provider, trackingType) so we can
|
||||
# compute cost-bearing and token-bearing denominators for avg stats.
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["provider"],
|
||||
by=["provider", "trackingType"],
|
||||
where=where,
|
||||
sum={"costMicrodollars": True},
|
||||
sum={
|
||||
"costMicrodollars": True,
|
||||
"inputTokens": True,
|
||||
"outputTokens": True,
|
||||
},
|
||||
count=True,
|
||||
),
|
||||
# Percentile distribution of cost per request (respects all filters).
|
||||
query_raw_with_schema(
|
||||
"SELECT"
|
||||
" percentile_cont(0.5) WITHIN GROUP"
|
||||
' (ORDER BY "costMicrodollars") as p50,'
|
||||
" percentile_cont(0.75) WITHIN GROUP"
|
||||
' (ORDER BY "costMicrodollars") as p75,'
|
||||
" percentile_cont(0.95) WITHIN GROUP"
|
||||
' (ORDER BY "costMicrodollars") as p95,'
|
||||
" percentile_cont(0.99) WITHIN GROUP"
|
||||
' (ORDER BY "costMicrodollars") as p99'
|
||||
' FROM {schema_prefix}"PlatformCostLog"'
|
||||
f" WHERE {raw_where}",
|
||||
*raw_params,
|
||||
),
|
||||
# Histogram buckets for cost distribution (respects all filters).
|
||||
# NULL costMicrodollars is excluded explicitly to prevent such rows
|
||||
# from falling through all WHEN clauses into the ELSE '$10+' bucket.
|
||||
query_raw_with_schema(
|
||||
"SELECT"
|
||||
" CASE"
|
||||
' WHEN "costMicrodollars" < 500000'
|
||||
" THEN '$0-0.50'"
|
||||
' WHEN "costMicrodollars" < 1000000'
|
||||
" THEN '$0.50-1'"
|
||||
' WHEN "costMicrodollars" < 2000000'
|
||||
" THEN '$1-2'"
|
||||
' WHEN "costMicrodollars" < 5000000'
|
||||
" THEN '$2-5'"
|
||||
' WHEN "costMicrodollars" < 10000000'
|
||||
" THEN '$5-10'"
|
||||
" ELSE '$10+'"
|
||||
" END as bucket,"
|
||||
" COUNT(*) as count"
|
||||
' FROM {schema_prefix}"PlatformCostLog"'
|
||||
f' WHERE {raw_where} AND "costMicrodollars" IS NOT NULL'
|
||||
" GROUP BY bucket"
|
||||
' ORDER BY MIN("costMicrodollars")',
|
||||
*raw_params,
|
||||
),
|
||||
]
|
||||
|
||||
# Only run the unfiltered aggregate query when tracking_type is set;
|
||||
# when tracking_type is None, the filtered query already contains all
|
||||
# tracking types and reusing it avoids a redundant full aggregation.
|
||||
if tracking_type is not None:
|
||||
common_queries.append(
|
||||
# Total aggregate (no tracking_type filter): used to compute
|
||||
# cost_bearing_requests and token_bearing_requests denominators so
|
||||
# global avg stats remain meaningful when the caller filters the
|
||||
# main view by a specific tracking_type (e.g. 'tokens').
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["provider", "trackingType"],
|
||||
where=where_no_tracking_type,
|
||||
sum={
|
||||
"costMicrodollars": True,
|
||||
"inputTokens": True,
|
||||
"outputTokens": True,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*common_queries)
|
||||
|
||||
# Unpack results by name for clarity.
|
||||
by_provider_groups = results[0]
|
||||
by_user_groups = results[1]
|
||||
by_user_tracking_groups = results[2]
|
||||
total_user_groups = results[3]
|
||||
total_agg_groups = results[4]
|
||||
percentile_rows = results[5]
|
||||
bucket_rows = results[6]
|
||||
# When tracking_type is None, the filtered and unfiltered queries are
|
||||
# identical — reuse total_agg_groups to avoid the extra DB round-trip.
|
||||
total_agg_no_tracking_type_groups = (
|
||||
results[7] if tracking_type is not None else total_agg_groups
|
||||
)
|
||||
|
||||
# Compute token grand-totals from the unfiltered aggregate so they remain
|
||||
# consistent with the avg-token stats (which also use unfiltered data).
|
||||
# Using by_provider_groups here would give 0 tokens when tracking_type='cost_usd'
|
||||
# because cost_usd rows carry no token data, contradicting non-zero averages.
|
||||
total_input_tokens = sum(
|
||||
_si(r, "inputTokens")
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "tokens"
|
||||
)
|
||||
total_output_tokens = sum(
|
||||
_si(r, "outputTokens")
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "tokens"
|
||||
)
|
||||
|
||||
# Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
|
||||
@@ -328,6 +523,61 @@ async def get_platform_cost_dashboard(
|
||||
total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups)
|
||||
total_requests = sum(_ca(r) for r in total_agg_groups)
|
||||
|
||||
# Extract percentile values from the raw query result.
|
||||
pctl = percentile_rows[0] if percentile_rows else {}
|
||||
cost_p50 = float(pctl.get("p50") or 0)
|
||||
cost_p75 = float(pctl.get("p75") or 0)
|
||||
cost_p95 = float(pctl.get("p95") or 0)
|
||||
cost_p99 = float(pctl.get("p99") or 0)
|
||||
|
||||
# Build cost bucket list.
|
||||
cost_buckets: list[CostBucket] = [
|
||||
CostBucket(bucket=r["bucket"], count=int(r["count"])) for r in bucket_rows
|
||||
]
|
||||
|
||||
# Avg-stat numerators and denominators are derived from the unfiltered
|
||||
# aggregate so they remain meaningful when the caller filters by a specific
|
||||
# tracking_type. Example: filtering by 'tokens' excludes cost_usd rows from
|
||||
# total_agg_groups, so avg_cost would always be 0 if we used that; using
|
||||
# total_agg_no_tracking_type_groups gives the correct cost_usd total/count.
|
||||
avg_cost_total = sum(
|
||||
_si(r, "costMicrodollars")
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "cost_usd"
|
||||
)
|
||||
cost_bearing_requests = sum(
|
||||
_ca(r)
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "cost_usd"
|
||||
)
|
||||
avg_input_total = sum(
|
||||
_si(r, "inputTokens")
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "tokens"
|
||||
)
|
||||
avg_output_total = sum(
|
||||
_si(r, "outputTokens")
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "tokens"
|
||||
)
|
||||
# Token-bearing request count: only rows where trackingType == "tokens".
|
||||
# Token averages must use this denominator; cost_usd rows do not carry tokens.
|
||||
token_bearing_requests = sum(
|
||||
_ca(r)
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "tokens"
|
||||
)
|
||||
|
||||
# Per-user cost-bearing request count: used for per-user avg cost so the
|
||||
# denominator matches the numerator (cost_usd rows only, per user).
|
||||
user_cost_bearing_counts: dict[str, int] = {}
|
||||
for r in by_user_tracking_groups:
|
||||
if r.get("trackingType") == "cost_usd" and r.get("userId"):
|
||||
uid = r["userId"]
|
||||
user_cost_bearing_counts[uid] = user_cost_bearing_counts.get(uid, 0) + _ca(
|
||||
r
|
||||
)
|
||||
|
||||
return PlatformCostDashboard(
|
||||
by_provider=[
|
||||
ProviderCostSummary(
|
||||
@@ -355,12 +605,35 @@ async def get_platform_cost_dashboard(
|
||||
total_cache_read_tokens=_si(r, "cacheReadTokens"),
|
||||
total_cache_creation_tokens=_si(r, "cacheCreationTokens"),
|
||||
request_count=_ca(r),
|
||||
cost_bearing_request_count=user_cost_bearing_counts.get(
|
||||
r.get("userId") or "", 0
|
||||
),
|
||||
)
|
||||
for r in by_user_groups
|
||||
],
|
||||
total_cost_microdollars=total_cost,
|
||||
total_requests=total_requests,
|
||||
total_users=total_users,
|
||||
total_input_tokens=total_input_tokens,
|
||||
total_output_tokens=total_output_tokens,
|
||||
avg_input_tokens_per_request=(
|
||||
avg_input_total / token_bearing_requests
|
||||
if token_bearing_requests > 0
|
||||
else 0.0
|
||||
),
|
||||
avg_output_tokens_per_request=(
|
||||
avg_output_total / token_bearing_requests
|
||||
if token_bearing_requests > 0
|
||||
else 0.0
|
||||
),
|
||||
avg_cost_microdollars_per_request=(
|
||||
avg_cost_total / cost_bearing_requests if cost_bearing_requests > 0 else 0.0
|
||||
),
|
||||
cost_p50_microdollars=cost_p50,
|
||||
cost_p75_microdollars=cost_p75,
|
||||
cost_p95_microdollars=cost_p95,
|
||||
cost_p99_microdollars=cost_p99,
|
||||
cost_buckets=cost_buckets,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ from backend.util.json import SafeJson
|
||||
|
||||
from .platform_cost import (
|
||||
PlatformCostEntry,
|
||||
_build_prisma_where,
|
||||
_build_raw_where,
|
||||
_build_where,
|
||||
_mask_email,
|
||||
get_platform_cost_dashboard,
|
||||
@@ -156,6 +158,84 @@ class TestBuildWhere:
|
||||
assert 'p."trackingType" = $3' in sql
|
||||
|
||||
|
||||
class TestBuildPrismaWhere:
|
||||
def test_both_start_and_end(self):
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
|
||||
where = _build_prisma_where(start, end, None, None)
|
||||
assert where["createdAt"] == {"gte": start, "lte": end}
|
||||
|
||||
def test_end_only(self):
|
||||
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
|
||||
where = _build_prisma_where(None, end, None, None)
|
||||
assert where["createdAt"] == {"lte": end}
|
||||
|
||||
def test_start_only(self):
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
where = _build_prisma_where(start, None, None, None)
|
||||
assert where["createdAt"] == {"gte": start}
|
||||
|
||||
def test_no_filters(self):
|
||||
where = _build_prisma_where(None, None, None, None)
|
||||
assert "createdAt" not in where
|
||||
|
||||
def test_provider_lowercased(self):
|
||||
where = _build_prisma_where(None, None, "OpenAI", None)
|
||||
assert where["provider"] == "openai"
|
||||
|
||||
def test_model_filter(self):
|
||||
where = _build_prisma_where(None, None, None, None, model="gpt-4")
|
||||
assert where["model"] == "gpt-4"
|
||||
|
||||
def test_block_name_case_insensitive(self):
|
||||
where = _build_prisma_where(None, None, None, None, block_name="LLMBlock")
|
||||
assert where["blockName"] == {"equals": "LLMBlock", "mode": "insensitive"}
|
||||
|
||||
def test_tracking_type(self):
|
||||
where = _build_prisma_where(None, None, None, None, tracking_type="tokens")
|
||||
assert where["trackingType"] == "tokens"
|
||||
|
||||
|
||||
class TestBuildRawWhere:
|
||||
def test_end_filter(self):
|
||||
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
|
||||
sql, params = _build_raw_where(None, end, None, None)
|
||||
assert '"createdAt" <= $2::timestamptz' in sql
|
||||
assert end in params
|
||||
|
||||
def test_model_filter(self):
|
||||
sql, params = _build_raw_where(None, None, None, None, model="gpt-4")
|
||||
assert '"model" = $' in sql
|
||||
assert "gpt-4" in params
|
||||
|
||||
def test_block_name_filter(self):
|
||||
sql, params = _build_raw_where(None, None, None, None, block_name="LLMBlock")
|
||||
assert 'LOWER("blockName") = LOWER($' in sql
|
||||
assert "LLMBlock" in params
|
||||
|
||||
def test_all_filters_combined(self):
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
|
||||
sql, params = _build_raw_where(
|
||||
start, end, "anthropic", "u1", model="claude-3", block_name="LLM"
|
||||
)
|
||||
# trackingType (default), start, end, provider, user_id, model, block_name
|
||||
assert len(params) == 7
|
||||
assert "anthropic" in params
|
||||
assert "u1" in params
|
||||
assert "claude-3" in params
|
||||
assert "LLM" in params
|
||||
|
||||
def test_default_tracking_type_is_cost_usd(self):
|
||||
sql, params = _build_raw_where(None, None, None, None)
|
||||
assert '"trackingType" = $1' in sql
|
||||
assert params[0] == "cost_usd"
|
||||
|
||||
def test_explicit_tracking_type_overrides_default(self):
|
||||
sql, params = _build_raw_where(None, None, None, None, tracking_type="tokens")
|
||||
assert params[0] == "tokens"
|
||||
|
||||
|
||||
def _make_entry(**overrides: object) -> PlatformCostEntry:
|
||||
return PlatformCostEntry.model_validate(
|
||||
{
|
||||
@@ -286,8 +366,9 @@ class TestGetPlatformCostDashboard:
|
||||
side_effect=[
|
||||
[provider_row], # by_provider
|
||||
[user_row], # by_user
|
||||
[], # by_user_tracking_groups (no cost_usd rows for this user)
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[provider_row], # total agg
|
||||
[provider_row], # total agg (tracking_type=None → same as unfiltered)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[mock_user])
|
||||
@@ -301,6 +382,14 @@ class TestGetPlatformCostDashboard:
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[
|
||||
[{"p50": 1000, "p75": 2000, "p95": 4000, "p99": 5000}],
|
||||
[{"bucket": "$0-0.50", "count": 3}],
|
||||
],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard()
|
||||
|
||||
@@ -313,6 +402,131 @@ class TestGetPlatformCostDashboard:
|
||||
assert dashboard.by_provider[0].total_duration_seconds == 10.5
|
||||
assert len(dashboard.by_user) == 1
|
||||
assert dashboard.by_user[0].email == "a***@b.com"
|
||||
assert dashboard.cost_p50_microdollars == 1000
|
||||
assert dashboard.cost_p75_microdollars == 2000
|
||||
assert dashboard.cost_p95_microdollars == 4000
|
||||
assert dashboard.cost_p99_microdollars == 5000
|
||||
assert len(dashboard.cost_buckets) == 1
|
||||
# total_input/output_tokens come from total_agg_no_tracking_type_groups
|
||||
# (provider_row has 1000/500)
|
||||
assert dashboard.total_input_tokens == 1000
|
||||
assert dashboard.total_output_tokens == 500
|
||||
# Token averages must use token_bearing_requests (3) not cost_bearing (0)
|
||||
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 3)
|
||||
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 3)
|
||||
# No cost_usd rows in total_agg → avg_cost should be 0
|
||||
assert dashboard.avg_cost_microdollars_per_request == 0.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_bearing_request_count_nonzero_when_filtering_by_tokens(self):
|
||||
"""When filtering by tracking_type='tokens', cost_bearing_request_count
|
||||
must still reflect cost_usd rows because by_user_tracking_groups is
|
||||
queried without the tracking_type constraint."""
|
||||
# total_agg only has a tokens row (because of the tracking_type filter)
|
||||
total_row = _make_group_by_row(
|
||||
provider="openai", tracking_type="tokens", cost=0, count=5
|
||||
)
|
||||
# by_user_tracking_groups returns BOTH rows (no tracking_type filter)
|
||||
user_tracking_cost_usd_row = {
|
||||
"_count": {"_all": 7},
|
||||
"userId": "u1",
|
||||
"trackingType": "cost_usd",
|
||||
}
|
||||
user_tracking_tokens_row = {
|
||||
"_count": {"_all": 5},
|
||||
"userId": "u1",
|
||||
"trackingType": "tokens",
|
||||
}
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(
|
||||
side_effect=[
|
||||
[total_row], # by_provider
|
||||
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
|
||||
[
|
||||
user_tracking_cost_usd_row,
|
||||
user_tracking_tokens_row,
|
||||
], # by_user_tracking
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[total_row], # total agg (filtered)
|
||||
[total_row], # total agg (no tracking_type filter)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[], []],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
|
||||
|
||||
# by_user has 1 user with 5 total requests (tokens rows only due to filter)
|
||||
# but per-user cost_bearing count should be 7 (from cost_usd rows in
|
||||
# by_user_tracking_groups which uses where_no_tracking_type)
|
||||
assert len(dashboard.by_user) == 1
|
||||
assert dashboard.by_user[0].cost_bearing_request_count == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_avg_cost_nonzero_when_filtering_by_tokens(self):
|
||||
"""When filtering by tracking_type='tokens', avg_cost_microdollars_per_request
|
||||
must still reflect cost_usd rows from total_agg_no_tracking_type_groups,
|
||||
not the filtered total_agg_groups which only has tokens rows."""
|
||||
# filtered total_agg only has tokens rows (zero cost)
|
||||
tokens_row = _make_group_by_row(
|
||||
provider="openai", tracking_type="tokens", cost=0, count=5
|
||||
)
|
||||
# unfiltered total_agg has both rows (cost_usd carries the actual cost)
|
||||
cost_usd_row = _make_group_by_row(
|
||||
provider="openai", tracking_type="cost_usd", cost=10_000, count=4
|
||||
)
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(
|
||||
side_effect=[
|
||||
[tokens_row], # by_provider
|
||||
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
|
||||
[], # by_user_tracking_groups
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[tokens_row], # total agg (filtered — tokens only)
|
||||
[tokens_row, cost_usd_row], # total agg (no tracking_type filter)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[], []],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
|
||||
|
||||
# avg_cost_microdollars_per_request must be non-zero: cost_usd row
|
||||
# (10_000 microdollars, 4 requests) is present in the unfiltered agg.
|
||||
assert dashboard.avg_cost_microdollars_per_request == pytest.approx(10_000 / 4)
|
||||
# avg token stats use token_bearing_requests from unfiltered agg (5)
|
||||
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 5)
|
||||
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_aggregated_not_hardcoded(self):
|
||||
@@ -335,8 +549,9 @@ class TestGetPlatformCostDashboard:
|
||||
side_effect=[
|
||||
[provider_row], # by_provider
|
||||
[user_row], # by_user
|
||||
[], # by_user_tracking_groups
|
||||
[{"userId": "u2"}], # distinct users
|
||||
[provider_row], # total agg
|
||||
[provider_row], # total agg (tracking_type=None → same as unfiltered)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
@@ -350,6 +565,14 @@ class TestGetPlatformCostDashboard:
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[
|
||||
[{"p50": 0, "p75": 0, "p95": 0, "p99": 0}],
|
||||
[],
|
||||
],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard()
|
||||
|
||||
@@ -361,7 +584,7 @@ class TestGetPlatformCostDashboard:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_dashboard(self):
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
@@ -373,6 +596,11 @@ class TestGetPlatformCostDashboard:
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[], []],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard()
|
||||
|
||||
@@ -381,13 +609,56 @@ class TestGetPlatformCostDashboard:
|
||||
assert dashboard.total_users == 0
|
||||
assert dashboard.by_provider == []
|
||||
assert dashboard.by_user == []
|
||||
assert dashboard.cost_p50_microdollars == 0
|
||||
assert dashboard.cost_buckets == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_filters_to_queries(self):
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
raw_mock = AsyncMock(side_effect=[[], []])
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
raw_mock,
|
||||
),
|
||||
):
|
||||
await get_platform_cost_dashboard(
|
||||
start=start, provider="openai", user_id="u1"
|
||||
)
|
||||
|
||||
# group_by called 5 times (by_provider, by_user, by_user_tracking, distinct users,
|
||||
# total agg filtered); the 6th call (total agg no-tracking-type) only runs
|
||||
# when tracking_type is set.
|
||||
assert mock_actions.group_by.await_count == 5
|
||||
# The where dict passed to the first call should include createdAt
|
||||
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
|
||||
assert "createdAt" in first_call_kwargs.get("where", {})
|
||||
# Raw SQL queries should receive provider and user_id as parameters
|
||||
assert raw_mock.await_count == 2
|
||||
raw_call_args = raw_mock.call_args_list[0][0] # positional args of 1st call
|
||||
raw_params = raw_call_args[1:] # first arg is the query template
|
||||
assert "openai" in raw_params
|
||||
assert "u1" in raw_params
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_tracking_groups_excludes_tracking_type_filter(self):
|
||||
"""by_user_tracking_groups must NOT apply the tracking_type filter so that
|
||||
cost_usd rows are always included even when the caller filters by 'tokens'."""
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], [], []])
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
@@ -399,16 +670,23 @@ class TestGetPlatformCostDashboard:
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[], []],
|
||||
),
|
||||
):
|
||||
await get_platform_cost_dashboard(
|
||||
start=start, provider="openai", user_id="u1"
|
||||
)
|
||||
await get_platform_cost_dashboard(tracking_type="tokens")
|
||||
|
||||
# group_by called 4 times (by_provider, by_user, distinct users, totals)
|
||||
assert mock_actions.group_by.await_count == 4
|
||||
# The where dict passed to the first call should include createdAt
|
||||
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
|
||||
assert "createdAt" in first_call_kwargs.get("where", {})
|
||||
# Call index 2 is by_user_tracking_groups (0=by_provider, 1=by_user,
|
||||
# 2=by_user_tracking, 3=distinct_users, 4=total_agg, 5=total_agg_no_tt).
|
||||
tracking_call_where = mock_actions.group_by.call_args_list[2][1]["where"]
|
||||
# The main filter applies trackingType; by_user_tracking must NOT.
|
||||
assert "trackingType" not in tracking_call_where
|
||||
# Other filters (e.g., date range, provider) are still passed through.
|
||||
# The first call (by_provider) should have trackingType in its where dict.
|
||||
provider_call_where = mock_actions.group_by.call_args_list[0][1]["where"]
|
||||
assert "trackingType" in provider_call_where
|
||||
|
||||
|
||||
def _make_prisma_log_row(
|
||||
|
||||
509
autogpt_platform/backend/backend/executor/billing.py
Normal file
509
autogpt_platform/backend/backend/executor/billing.py
Normal file
@@ -0,0 +1,509 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import Block
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.execution import (
|
||||
ExecutionStatus,
|
||||
GraphExecutionEntry,
|
||||
NodeExecutionEntry,
|
||||
)
|
||||
from backend.data.graph import Node
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationType,
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util.clients import (
|
||||
get_database_manager_client,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.metrics import DiscordChannel
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from .utils import LogMetadata, block_usage_cost, execution_usage_cost
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.data.db_manager import DatabaseManagerClient
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
logger = TruncatedLogger(_logger, prefix="[Billing]")
|
||||
settings = Settings()
|
||||
|
||||
# Redis key prefix for tracking insufficient funds Discord notifications.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
|
||||
# TTL for the notification flag (30 days) - acts as a fallback cleanup
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
|
||||
|
||||
# Hard cap on the multiplier passed to charge_extra_runtime_cost to
|
||||
# protect against a corrupted llm_call_count draining a user's balance.
|
||||
# Real agent-mode runs are bounded by agent_mode_max_iterations (~50);
|
||||
# 200 leaves headroom while preventing runaway charges.
|
||||
_MAX_EXTRA_RUNTIME_COST = 200
|
||||
|
||||
|
||||
def get_db_client() -> "DatabaseManagerClient":
|
||||
return get_database_manager_client()
|
||||
|
||||
|
||||
async def clear_insufficient_funds_notifications(user_id: str) -> int:
|
||||
"""
|
||||
Clear all insufficient funds notification flags for a user.
|
||||
|
||||
This should be called when a user tops up their credits, allowing
|
||||
Discord notifications to be sent again if they run out of funds.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to clear notifications for.
|
||||
|
||||
Returns:
|
||||
The number of keys that were deleted.
|
||||
"""
|
||||
try:
|
||||
redis_client = await redis.get_redis_async()
|
||||
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
keys = [key async for key in redis_client.scan_iter(match=pattern)]
|
||||
if keys:
|
||||
return await redis_client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear insufficient funds notification flags for user "
|
||||
f"{user_id}: {e}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def resolve_block_cost(
|
||||
node_exec: NodeExecutionEntry,
|
||||
) -> tuple["Block | None", int, dict[str, Any]]:
|
||||
"""Look up the block and compute its base usage cost for an exec.
|
||||
|
||||
Shared by charge_usage and charge_extra_runtime_cost so the
|
||||
(get_block, block_usage_cost) lookup lives in exactly one place.
|
||||
Returns ``(block, cost, matching_filter)``. ``block`` is ``None`` if
|
||||
the block id can't be resolved — callers should treat that as
|
||||
"nothing to charge".
|
||||
"""
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return None, 0, {}
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.inputs)
|
||||
return block, cost, matching_filter
|
||||
|
||||
|
||||
def charge_usage(
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
) -> tuple[int, int]:
|
||||
total_cost = 0
|
||||
remaining_balance = 0
|
||||
db_client = get_db_client()
|
||||
block, cost, matching_filter = resolve_block_cost(node_exec)
|
||||
if not block:
|
||||
return total_cost, 0
|
||||
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
block=block.name,
|
||||
input=matching_filter,
|
||||
reason=f"Ran block {node_exec.block_id} {block.name}",
|
||||
),
|
||||
)
|
||||
total_cost += cost
|
||||
|
||||
# execution_count=0 is used by charge_node_usage for nested tool calls
|
||||
# which must not be pushed into higher execution-count tiers.
|
||||
# execution_usage_cost(0) would trigger a charge because 0 % threshold == 0,
|
||||
# so skip it entirely when execution_count is 0.
|
||||
cost, usage_count = (
|
||||
execution_usage_cost(execution_count) if execution_count > 0 else (0, 0)
|
||||
)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
input={
|
||||
"execution_count": usage_count,
|
||||
"charge": "Execution Cost",
|
||||
},
|
||||
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
|
||||
),
|
||||
)
|
||||
total_cost += cost
|
||||
|
||||
return total_cost, remaining_balance
|
||||
|
||||
|
||||
def _charge_extra_runtime_cost_sync(
|
||||
node_exec: NodeExecutionEntry,
|
||||
capped_count: int,
|
||||
) -> tuple[int, int]:
|
||||
"""Synchronous implementation — runs in a thread-pool worker.
|
||||
|
||||
Called only from charge_extra_runtime_cost. Do not call directly from
|
||||
async code.
|
||||
|
||||
Note: ``resolve_block_cost`` is called again here (rather than reusing
|
||||
the result from ``charge_usage`` at the start of execution) because the
|
||||
two calls happen in separate thread-pool workers and sharing mutable
|
||||
state across workers would require locks. The block config is immutable
|
||||
during a run, so the repeated lookup is safe and produces the same cost;
|
||||
the only overhead is an extra registry lookup.
|
||||
"""
|
||||
db_client = get_db_client()
|
||||
block, cost, matching_filter = resolve_block_cost(node_exec)
|
||||
if not block or cost <= 0:
|
||||
return 0, 0
|
||||
total_extra_cost = cost * capped_count
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=total_extra_cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
block=block.name,
|
||||
input={
|
||||
**matching_filter,
|
||||
"extra_runtime_cost_count": capped_count,
|
||||
},
|
||||
reason=(
|
||||
f"Extra agent-mode iterations for {block.name} "
|
||||
f"({capped_count} additional LLM calls)"
|
||||
),
|
||||
),
|
||||
)
|
||||
return total_extra_cost, remaining_balance
|
||||
|
||||
|
||||
async def charge_extra_runtime_cost(
|
||||
node_exec: NodeExecutionEntry,
|
||||
extra_count: int,
|
||||
) -> tuple[int, int]:
|
||||
"""Charge a block extra runtime cost beyond the initial run.
|
||||
|
||||
Used by agent-mode blocks (e.g. OrchestratorBlock) that make multiple
|
||||
LLM calls within a single node execution. The first iteration is already
|
||||
charged by charge_usage; this method charges *extra_count* additional
|
||||
copies of the block's base cost.
|
||||
|
||||
Returns ``(total_extra_cost, remaining_balance)``. May raise
|
||||
``InsufficientBalanceError`` if the user can't afford the charge.
|
||||
"""
|
||||
if extra_count <= 0:
|
||||
return 0, 0
|
||||
# Cap to protect against a corrupted llm_call_count.
|
||||
capped = min(extra_count, _MAX_EXTRA_RUNTIME_COST)
|
||||
if extra_count > _MAX_EXTRA_RUNTIME_COST:
|
||||
logger.warning(
|
||||
f"extra_count {extra_count} exceeds cap {_MAX_EXTRA_RUNTIME_COST};"
|
||||
f" charging {_MAX_EXTRA_RUNTIME_COST} (llm_call_count may be corrupted)"
|
||||
)
|
||||
return await asyncio.to_thread(_charge_extra_runtime_cost_sync, node_exec, capped)
|
||||
|
||||
|
||||
async def charge_node_usage(node_exec: NodeExecutionEntry) -> tuple[int, int]:
|
||||
"""Charge a single node execution to the user.
|
||||
|
||||
Public async wrapper around charge_usage for blocks (e.g. the
|
||||
OrchestratorBlock) that spawn nested node executions outside the main
|
||||
queue and therefore need to charge them explicitly.
|
||||
|
||||
Also handles low-balance notification so callers don't need to touch
|
||||
private functions directly.
|
||||
|
||||
Note: this **does not** increment the global execution counter
|
||||
(``increment_execution_count``). Nested tool executions are sub-steps
|
||||
of a single block run from the user's perspective and should not push
|
||||
them into higher per-execution cost tiers.
|
||||
"""
|
||||
|
||||
def _run():
|
||||
total_cost, remaining = charge_usage(node_exec, 0)
|
||||
if total_cost > 0:
|
||||
handle_low_balance(
|
||||
get_db_client(), node_exec.user_id, remaining, total_cost
|
||||
)
|
||||
return total_cost, remaining
|
||||
|
||||
return await asyncio.to_thread(_run)
|
||||
|
||||
|
||||
async def try_send_insufficient_funds_notif(
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
error: InsufficientBalanceError,
|
||||
log_metadata: LogMetadata,
|
||||
) -> None:
|
||||
"""Send an insufficient-funds notification, swallowing failures."""
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
handle_insufficient_funds_notif,
|
||||
get_db_client(),
|
||||
user_id,
|
||||
graph_id,
|
||||
error,
|
||||
)
|
||||
except Exception as notif_error: # pragma: no cover
|
||||
log_metadata.warning(
|
||||
f"Failed to send insufficient funds notification: {notif_error}"
|
||||
)
|
||||
|
||||
|
||||
async def handle_post_execution_billing(
|
||||
node: Node,
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_stats: NodeExecutionStats,
|
||||
status: ExecutionStatus,
|
||||
log_metadata: LogMetadata,
|
||||
) -> None:
|
||||
"""Charge extra runtime cost for blocks that opt into per-LLM-call billing.
|
||||
|
||||
The first LLM call is already covered by charge_usage(); each additional
|
||||
call costs another base_cost. Skipped for dry runs and failed runs.
|
||||
|
||||
InsufficientBalanceError here is a post-hoc billing leak: the work is
|
||||
already done but the user can no longer pay. The run stays COMPLETED and
|
||||
the error is logged with ``billing_leak: True`` for alerting.
|
||||
"""
|
||||
extra_iterations = (
|
||||
cast(Block, node.block).extra_runtime_cost(execution_stats)
|
||||
if status == ExecutionStatus.COMPLETED
|
||||
and not node_exec.execution_context.dry_run
|
||||
else 0
|
||||
)
|
||||
if extra_iterations <= 0:
|
||||
return
|
||||
|
||||
try:
|
||||
extra_cost, remaining_balance = await charge_extra_runtime_cost(
|
||||
node_exec,
|
||||
extra_iterations,
|
||||
)
|
||||
if extra_cost > 0:
|
||||
execution_stats.extra_cost += extra_cost
|
||||
await asyncio.to_thread(
|
||||
handle_low_balance,
|
||||
get_db_client(),
|
||||
node_exec.user_id,
|
||||
remaining_balance,
|
||||
extra_cost,
|
||||
)
|
||||
except InsufficientBalanceError as e:
|
||||
log_metadata.error(
|
||||
"billing_leak: insufficient balance after "
|
||||
f"{node.block.name} completed {extra_iterations} "
|
||||
f"extra iterations",
|
||||
extra={
|
||||
"billing_leak": True,
|
||||
"user_id": node_exec.user_id,
|
||||
"graph_id": node_exec.graph_id,
|
||||
"block_id": node_exec.block_id,
|
||||
"extra_runtime_cost_count": extra_iterations,
|
||||
"error": str(e),
|
||||
},
|
||||
)
|
||||
# Do NOT set execution_stats.error — the node ran to completion,
|
||||
# only the post-hoc charge failed. See class-level billing-leak
|
||||
# contract documentation.
|
||||
await try_send_insufficient_funds_notif(
|
||||
node_exec.user_id,
|
||||
node_exec.graph_id,
|
||||
e,
|
||||
log_metadata,
|
||||
)
|
||||
except Exception as e:
|
||||
log_metadata.error(
|
||||
f"billing_leak: failed to charge extra iterations for {node.block.name}",
|
||||
extra={
|
||||
"billing_leak": True,
|
||||
"user_id": node_exec.user_id,
|
||||
"graph_id": node_exec.graph_id,
|
||||
"block_id": node_exec.block_id,
|
||||
"extra_runtime_cost_count": extra_iterations,
|
||||
"error_type": type(e).__name__,
|
||||
"error": str(e),
|
||||
},
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
|
||||
def handle_agent_run_notif(
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec: GraphExecutionEntry,
|
||||
exec_stats: GraphExecutionStats,
|
||||
) -> None:
|
||||
metadata = db_client.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
|
||||
named_outputs = [
|
||||
{
|
||||
key: value[0] if key == "name" else value
|
||||
for key, value in output.output_data.items()
|
||||
}
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=graph_exec.user_id,
|
||||
type=NotificationType.AGENT_RUN,
|
||||
data=AgentRunData(
|
||||
outputs=named_outputs,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
credits_used=exec_stats.cost,
|
||||
execution_time=exec_stats.walltime,
|
||||
graph_id=graph_exec.graph_id,
|
||||
node_count=exec_stats.node_count,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def handle_insufficient_funds_notif(
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
) -> None:
|
||||
# Check if we've already sent a notification for this user+agent combo.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
try:
|
||||
redis_client = redis.get_redis()
|
||||
# SET NX returns True only if the key was newly set (didn't exist)
|
||||
is_new_notification = redis_client.set(
|
||||
redis_key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
|
||||
)
|
||||
if not is_new_notification:
|
||||
# Already notified for this user+agent, skip all notifications
|
||||
logger.debug(
|
||||
f"Skipping duplicate insufficient funds notification for "
|
||||
f"user={user_id}, graph={graph_id}"
|
||||
)
|
||||
return
|
||||
except Exception as redis_error:
|
||||
# If Redis fails, log and continue to send the notification
|
||||
# (better to occasionally duplicate than to never notify)
|
||||
logger.warning(
|
||||
f"Failed to check/set insufficient funds notification flag in Redis: "
|
||||
f"{redis_error}"
|
||||
)
|
||||
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
|
||||
# Queue user email notification
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.ZERO_BALANCE,
|
||||
data=ZeroBalanceData(
|
||||
current_balance=e.balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
shortfall=shortfall,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Send Discord system alert
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
alert_message = (
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
|
||||
f"Current balance: ${e.balance / 100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as alert_error:
|
||||
logger.error(f"Failed to send insufficient funds Discord alert: {alert_error}")
|
||||
|
||||
|
||||
def handle_low_balance(
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
current_balance: int,
|
||||
transaction_cost: int,
|
||||
) -> None:
|
||||
"""Check and handle low balance scenarios after a transaction"""
|
||||
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
|
||||
|
||||
balance_before = current_balance + transaction_cost
|
||||
|
||||
if (
|
||||
current_balance < LOW_BALANCE_THRESHOLD
|
||||
and balance_before >= LOW_BALANCE_THRESHOLD
|
||||
):
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
current_balance=current_balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
|
||||
f"Current balance: ${current_balance / 100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send low balance Discord alert: {e}")
|
||||
@@ -21,11 +21,9 @@ from sentry_sdk.api import get_current_scope as _sentry_get_current_scope
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import BlockSchema
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
from backend.data import redis_client as redis
|
||||
from backend.data.block import BlockInput, BlockOutput, BlockOutputEntry
|
||||
from backend.data.credit import UsageTransactionMetadata
|
||||
from backend.data.dynamic_fields import parse_execution_output
|
||||
from backend.data.execution import (
|
||||
ExecutionContext,
|
||||
@@ -39,27 +37,18 @@ from backend.data.execution import (
|
||||
)
|
||||
from backend.data.graph import Link, Node
|
||||
from backend.data.model import GraphExecutionStats, NodeExecutionStats
|
||||
from backend.data.notifications import (
|
||||
AgentRunData,
|
||||
LowBalanceData,
|
||||
NotificationEventModel,
|
||||
NotificationType,
|
||||
ZeroBalanceData,
|
||||
)
|
||||
from backend.data.rabbitmq import SyncRabbitMQ
|
||||
from backend.executor.cost_tracking import (
|
||||
drain_pending_cost_logs,
|
||||
log_system_credential_cost,
|
||||
)
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.notifications.notifications import queue_notification
|
||||
from backend.util import json
|
||||
from backend.util.clients import (
|
||||
get_async_execution_event_bus,
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_execution_event_bus,
|
||||
get_notification_manager_client,
|
||||
)
|
||||
from backend.util.decorator import (
|
||||
async_error_logged,
|
||||
@@ -75,7 +64,6 @@ from backend.util.exceptions import (
|
||||
)
|
||||
from backend.util.file import clean_exec_files
|
||||
from backend.util.logging import TruncatedLogger, configure_logging
|
||||
from backend.util.metrics import DiscordChannel
|
||||
from backend.util.process import AppProcess, set_service_name
|
||||
from backend.util.retry import (
|
||||
continuous_retry,
|
||||
@@ -84,6 +72,7 @@ from backend.util.retry import (
|
||||
)
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from . import billing
|
||||
from .activity_status_generator import generate_activity_status_for_execution
|
||||
from .automod.manager import automod_manager
|
||||
from .cluster_lock import ClusterLock
|
||||
@@ -98,9 +87,7 @@ from .utils import (
|
||||
ExecutionOutputEntry,
|
||||
LogMetadata,
|
||||
NodeExecutionProgress,
|
||||
block_usage_cost,
|
||||
create_execution_queue_config,
|
||||
execution_usage_cost,
|
||||
validate_exec,
|
||||
)
|
||||
|
||||
@@ -126,40 +113,6 @@ utilization_gauge = Gauge(
|
||||
"Ratio of active graph runs to max graph workers",
|
||||
)
|
||||
|
||||
# Redis key prefix for tracking insufficient funds Discord notifications.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX = "insufficient_funds_discord_notified"
|
||||
# TTL for the notification flag (30 days) - acts as a fallback cleanup
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS = 30 * 24 * 60 * 60
|
||||
|
||||
|
||||
async def clear_insufficient_funds_notifications(user_id: str) -> int:
|
||||
"""
|
||||
Clear all insufficient funds notification flags for a user.
|
||||
|
||||
This should be called when a user tops up their credits, allowing
|
||||
Discord notifications to be sent again if they run out of funds.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to clear notifications for.
|
||||
|
||||
Returns:
|
||||
The number of keys that were deleted.
|
||||
"""
|
||||
try:
|
||||
redis_client = await redis.get_redis_async()
|
||||
pattern = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:*"
|
||||
keys = [key async for key in redis_client.scan_iter(match=pattern)]
|
||||
if keys:
|
||||
return await redis_client.delete(*keys)
|
||||
return 0
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to clear insufficient funds notification flags for user "
|
||||
f"{user_id}: {e}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
# Thread-local storage for ExecutionProcessor instances
|
||||
_tls = threading.local()
|
||||
@@ -681,12 +634,16 @@ class ExecutionProcessor:
|
||||
execution_stats.walltime = timing_info.wall_time
|
||||
execution_stats.cputime = timing_info.cpu_time
|
||||
|
||||
await billing.handle_post_execution_billing(
|
||||
node, node_exec, execution_stats, status, log_metadata
|
||||
)
|
||||
|
||||
graph_stats, graph_stats_lock = graph_stats_pair
|
||||
with graph_stats_lock:
|
||||
graph_stats.node_count += 1 + execution_stats.extra_steps
|
||||
graph_stats.nodes_cputime += execution_stats.cputime
|
||||
graph_stats.nodes_walltime += execution_stats.walltime
|
||||
graph_stats.cost += execution_stats.extra_cost
|
||||
graph_stats.cost += execution_stats.cost + execution_stats.extra_cost
|
||||
if isinstance(execution_stats.error, Exception):
|
||||
graph_stats.node_error_count += 1
|
||||
|
||||
@@ -716,6 +673,18 @@ class ExecutionProcessor:
|
||||
db_client=db_client,
|
||||
)
|
||||
|
||||
# If the node failed because a nested tool charge raised IBE,
|
||||
# send the user notification so they understand why the run stopped.
|
||||
if status == ExecutionStatus.FAILED and isinstance(
|
||||
execution_stats.error, InsufficientBalanceError
|
||||
):
|
||||
await billing.try_send_insufficient_funds_notif(
|
||||
node_exec.user_id,
|
||||
node_exec.graph_id,
|
||||
execution_stats.error,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
return execution_stats
|
||||
|
||||
@async_time_measured
|
||||
@@ -935,7 +904,7 @@ class ExecutionProcessor:
|
||||
)
|
||||
finally:
|
||||
# Communication handling
|
||||
self._handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
billing.handle_agent_run_notif(db_client, graph_exec, exec_stats)
|
||||
|
||||
update_graph_execution_state(
|
||||
db_client=db_client,
|
||||
@@ -944,57 +913,18 @@ class ExecutionProcessor:
|
||||
stats=exec_stats,
|
||||
)
|
||||
|
||||
def _charge_usage(
|
||||
async def charge_node_usage(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
) -> tuple[int, int]:
|
||||
total_cost = 0
|
||||
remaining_balance = 0
|
||||
db_client = get_db_client()
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return total_cost, 0
|
||||
return await billing.charge_node_usage(node_exec)
|
||||
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
node_exec_id=node_exec.node_exec_id,
|
||||
node_id=node_exec.node_id,
|
||||
block_id=node_exec.block_id,
|
||||
block=block.name,
|
||||
input=matching_filter,
|
||||
reason=f"Ran block {node_exec.block_id} {block.name}",
|
||||
),
|
||||
)
|
||||
total_cost += cost
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
input={
|
||||
"execution_count": usage_count,
|
||||
"charge": "Execution Cost",
|
||||
},
|
||||
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
|
||||
),
|
||||
)
|
||||
total_cost += cost
|
||||
|
||||
return total_cost, remaining_balance
|
||||
async def charge_extra_runtime_cost(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
extra_count: int,
|
||||
) -> tuple[int, int]:
|
||||
return await billing.charge_extra_runtime_cost(node_exec, extra_count)
|
||||
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
@@ -1106,7 +1036,7 @@ class ExecutionProcessor:
|
||||
# Charge usage (may raise) — skipped for dry runs
|
||||
try:
|
||||
if not graph_exec.execution_context.dry_run:
|
||||
cost, remaining_balance = self._charge_usage(
|
||||
cost, remaining_balance = billing.charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=increment_execution_count(
|
||||
graph_exec.user_id
|
||||
@@ -1115,7 +1045,7 @@ class ExecutionProcessor:
|
||||
with execution_stats_lock:
|
||||
execution_stats.cost += cost
|
||||
# Check if we crossed the low balance threshold
|
||||
self._handle_low_balance(
|
||||
billing.handle_low_balance(
|
||||
db_client=db_client,
|
||||
user_id=graph_exec.user_id,
|
||||
current_balance=remaining_balance,
|
||||
@@ -1135,7 +1065,7 @@ class ExecutionProcessor:
|
||||
status=ExecutionStatus.FAILED,
|
||||
)
|
||||
|
||||
self._handle_insufficient_funds_notif(
|
||||
billing.handle_insufficient_funds_notif(
|
||||
db_client,
|
||||
graph_exec.user_id,
|
||||
graph_exec.graph_id,
|
||||
@@ -1397,165 +1327,6 @@ class ExecutionProcessor:
|
||||
):
|
||||
execution_queue.add(next_execution)
|
||||
|
||||
def _handle_agent_run_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
graph_exec: GraphExecutionEntry,
|
||||
exec_stats: GraphExecutionStats,
|
||||
):
|
||||
metadata = db_client.get_graph_metadata(
|
||||
graph_exec.graph_id, graph_exec.graph_version
|
||||
)
|
||||
outputs = db_client.get_node_executions(
|
||||
graph_exec.graph_exec_id,
|
||||
block_ids=[AgentOutputBlock().id],
|
||||
)
|
||||
|
||||
named_outputs = [
|
||||
{
|
||||
key: value[0] if key == "name" else value
|
||||
for key, value in output.output_data.items()
|
||||
}
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=graph_exec.user_id,
|
||||
type=NotificationType.AGENT_RUN,
|
||||
data=AgentRunData(
|
||||
outputs=named_outputs,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
credits_used=exec_stats.cost,
|
||||
execution_time=exec_stats.walltime,
|
||||
graph_id=graph_exec.graph_id,
|
||||
node_count=exec_stats.node_count,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_insufficient_funds_notif(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
graph_id: str,
|
||||
e: InsufficientBalanceError,
|
||||
):
|
||||
# Check if we've already sent a notification for this user+agent combo.
|
||||
# We only send one notification per user per agent until they top up credits.
|
||||
redis_key = f"{INSUFFICIENT_FUNDS_NOTIFIED_PREFIX}:{user_id}:{graph_id}"
|
||||
try:
|
||||
redis_client = redis.get_redis()
|
||||
# SET NX returns True only if the key was newly set (didn't exist)
|
||||
is_new_notification = redis_client.set(
|
||||
redis_key,
|
||||
"1",
|
||||
nx=True,
|
||||
ex=INSUFFICIENT_FUNDS_NOTIFIED_TTL_SECONDS,
|
||||
)
|
||||
if not is_new_notification:
|
||||
# Already notified for this user+agent, skip all notifications
|
||||
logger.debug(
|
||||
f"Skipping duplicate insufficient funds notification for "
|
||||
f"user={user_id}, graph={graph_id}"
|
||||
)
|
||||
return
|
||||
except Exception as redis_error:
|
||||
# If Redis fails, log and continue to send the notification
|
||||
# (better to occasionally duplicate than to never notify)
|
||||
logger.warning(
|
||||
f"Failed to check/set insufficient funds notification flag in Redis: "
|
||||
f"{redis_error}"
|
||||
)
|
||||
|
||||
shortfall = abs(e.amount) - e.balance
|
||||
metadata = db_client.get_graph_metadata(graph_id)
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
|
||||
# Queue user email notification
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.ZERO_BALANCE,
|
||||
data=ZeroBalanceData(
|
||||
current_balance=e.balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
shortfall=shortfall,
|
||||
agent_name=metadata.name if metadata else "Unknown Agent",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Send Discord system alert
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
|
||||
alert_message = (
|
||||
f"❌ **Insufficient Funds Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Agent: {metadata.name if metadata else 'Unknown Agent'}\n"
|
||||
f"Current balance: ${e.balance / 100:.2f}\n"
|
||||
f"Attempted cost: ${abs(e.amount) / 100:.2f}\n"
|
||||
f"Shortfall: ${abs(shortfall) / 100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as alert_error:
|
||||
logger.error(
|
||||
f"Failed to send insufficient funds Discord alert: {alert_error}"
|
||||
)
|
||||
|
||||
def _handle_low_balance(
|
||||
self,
|
||||
db_client: "DatabaseManagerClient",
|
||||
user_id: str,
|
||||
current_balance: int,
|
||||
transaction_cost: int,
|
||||
):
|
||||
"""Check and handle low balance scenarios after a transaction"""
|
||||
LOW_BALANCE_THRESHOLD = settings.config.low_balance_threshold
|
||||
|
||||
balance_before = current_balance + transaction_cost
|
||||
|
||||
if (
|
||||
current_balance < LOW_BALANCE_THRESHOLD
|
||||
and balance_before >= LOW_BALANCE_THRESHOLD
|
||||
):
|
||||
base_url = (
|
||||
settings.config.frontend_base_url or settings.config.platform_base_url
|
||||
)
|
||||
queue_notification(
|
||||
NotificationEventModel(
|
||||
user_id=user_id,
|
||||
type=NotificationType.LOW_BALANCE,
|
||||
data=LowBalanceData(
|
||||
current_balance=current_balance,
|
||||
billing_page_link=f"{base_url}/profile/credits",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
user_email = db_client.get_user_email_by_id(user_id)
|
||||
alert_message = (
|
||||
f"⚠️ **Low Balance Alert**\n"
|
||||
f"User: {user_email or user_id}\n"
|
||||
f"Balance dropped below ${LOW_BALANCE_THRESHOLD / 100:.2f}\n"
|
||||
f"Current balance: ${current_balance / 100:.2f}\n"
|
||||
f"Transaction cost: ${transaction_cost / 100:.2f}\n"
|
||||
f"[View User Details]({base_url}/admin/spending?search={user_email})"
|
||||
)
|
||||
get_notification_manager_client().discord_system_alert(
|
||||
alert_message, DiscordChannel.PRODUCT
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to send low balance Discord alert: {e}")
|
||||
|
||||
|
||||
class ExecutionManager(AppProcess):
|
||||
def __init__(self):
|
||||
|
||||
@@ -4,9 +4,9 @@ import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import ZeroBalanceData
|
||||
from backend.executor.manager import (
|
||||
from backend.executor import billing
|
||||
from backend.executor.billing import (
|
||||
INSUFFICIENT_FUNDS_NOTIFIED_PREFIX,
|
||||
ExecutionProcessor,
|
||||
clear_insufficient_funds_notifications,
|
||||
)
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
@@ -25,7 +25,6 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
):
|
||||
"""Test that the first insufficient funds notification sends a Discord alert."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
@@ -36,13 +35,13 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
"backend.executor.billing.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
"backend.executor.billing.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
"backend.executor.billing.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
"backend.executor.billing.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
@@ -63,7 +62,7 @@ async def test_handle_insufficient_funds_sends_discord_alert_first_time(
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
billing.handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
@@ -99,7 +98,6 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
):
|
||||
"""Test that duplicate insufficient funds notifications skip both email and Discord."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
@@ -110,13 +108,13 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
"backend.executor.billing.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
"backend.executor.billing.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
"backend.executor.billing.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
"backend.executor.billing.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Setup mocks
|
||||
@@ -134,7 +132,7 @@ async def test_handle_insufficient_funds_skips_duplicate_notifications(
|
||||
mock_db_client.get_graph_metadata.return_value = MagicMock(name="Test Agent")
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
billing.handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
@@ -154,7 +152,6 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
):
|
||||
"""Test that different agents for the same user get separate Discord alerts."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id_1 = "test-graph-111"
|
||||
graph_id_2 = "test-graph-222"
|
||||
@@ -166,12 +163,12 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
amount=-714,
|
||||
)
|
||||
|
||||
with patch("backend.executor.manager.queue_notification"), patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
with patch("backend.executor.billing.queue_notification"), patch(
|
||||
"backend.executor.billing.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
"backend.executor.billing.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
"backend.executor.billing.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
@@ -190,7 +187,7 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# First agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
billing.handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_1,
|
||||
@@ -198,7 +195,7 @@ async def test_handle_insufficient_funds_different_agents_get_separate_alerts(
|
||||
)
|
||||
|
||||
# Second agent notification
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
billing.handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id_2,
|
||||
@@ -227,7 +224,7 @@ async def test_clear_insufficient_funds_notifications(server: SpinTestServer):
|
||||
|
||||
user_id = "test-user-123"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
with patch("backend.executor.billing.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
@@ -263,7 +260,7 @@ async def test_clear_insufficient_funds_notifications_no_keys(server: SpinTestSe
|
||||
|
||||
user_id = "test-user-no-notifications"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
with patch("backend.executor.billing.redis") as mock_redis_module:
|
||||
|
||||
mock_redis_client = MagicMock()
|
||||
# get_redis_async is an async function, so we need AsyncMock for it
|
||||
@@ -290,7 +287,7 @@ async def test_clear_insufficient_funds_notifications_handles_redis_error(
|
||||
|
||||
user_id = "test-user-redis-error"
|
||||
|
||||
with patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
with patch("backend.executor.billing.redis") as mock_redis_module:
|
||||
|
||||
# Mock get_redis_async to raise an error
|
||||
mock_redis_module.get_redis_async = AsyncMock(
|
||||
@@ -310,7 +307,6 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
):
|
||||
"""Test that both email and Discord notifications are still sent when Redis fails."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
graph_id = "test-graph-456"
|
||||
error = InsufficientBalanceError(
|
||||
@@ -321,13 +317,13 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
"backend.executor.billing.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
"backend.executor.billing.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
"backend.executor.billing.settings"
|
||||
) as mock_settings, patch(
|
||||
"backend.executor.manager.redis"
|
||||
"backend.executor.billing.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
mock_client = MagicMock()
|
||||
@@ -346,7 +342,7 @@ async def test_handle_insufficient_funds_continues_on_redis_error(
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the insufficient funds handler
|
||||
execution_processor._handle_insufficient_funds_notif(
|
||||
billing.handle_insufficient_funds_notif(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
graph_id=graph_id,
|
||||
@@ -370,7 +366,7 @@ async def test_add_transaction_clears_notifications_on_grant(server: SpinTestSer
|
||||
user_id = "test-user-grant-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
"backend.executor.billing.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
@@ -412,7 +408,7 @@ async def test_add_transaction_clears_notifications_on_top_up(server: SpinTestSe
|
||||
user_id = "test-user-topup-clear"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
"backend.executor.billing.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
@@ -450,7 +446,7 @@ async def test_add_transaction_skips_clearing_for_inactive_transaction(
|
||||
user_id = "test-user-inactive"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
"backend.executor.billing.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
@@ -486,7 +482,7 @@ async def test_add_transaction_skips_clearing_for_usage_transaction(
|
||||
user_id = "test-user-usage"
|
||||
|
||||
with patch("backend.data.credit.query_raw_with_schema") as mock_query, patch(
|
||||
"backend.executor.manager.redis"
|
||||
"backend.executor.billing.redis"
|
||||
) as mock_redis_module:
|
||||
|
||||
# Mock the query to return a successful transaction
|
||||
@@ -521,7 +517,7 @@ async def test_enable_transaction_clears_notifications(server: SpinTestServer):
|
||||
|
||||
with patch("backend.data.credit.CreditTransaction") as mock_credit_tx, patch(
|
||||
"backend.data.credit.query_raw_with_schema"
|
||||
) as mock_query, patch("backend.executor.manager.redis") as mock_redis_module:
|
||||
) as mock_query, patch("backend.executor.billing.redis") as mock_redis_module:
|
||||
|
||||
# Mock finding the pending transaction
|
||||
mock_transaction = MagicMock()
|
||||
|
||||
@@ -4,26 +4,25 @@ import pytest
|
||||
from prisma.enums import NotificationType
|
||||
|
||||
from backend.data.notifications import LowBalanceData
|
||||
from backend.executor.manager import ExecutionProcessor
|
||||
from backend.executor import billing
|
||||
from backend.util.test import SpinTestServer
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
"""Test that _handle_low_balance triggers notification when crossing threshold."""
|
||||
"""Test that handle_low_balance triggers notification when crossing threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 400 # $4 - below $5 threshold
|
||||
transaction_cost = 600 # $6 transaction
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
"backend.executor.billing.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
"backend.executor.billing.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
"backend.executor.billing.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
@@ -37,7 +36,7 @@ async def test_handle_low_balance_threshold_crossing(server: SpinTestServer):
|
||||
mock_db_client.get_user_email_by_id.return_value = "test@example.com"
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
billing.handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
@@ -69,7 +68,6 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
):
|
||||
"""Test that no notification is sent when not crossing the threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 600 # $6 - above $5 threshold
|
||||
transaction_cost = (
|
||||
@@ -78,11 +76,11 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
"backend.executor.billing.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
"backend.executor.billing.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
"backend.executor.billing.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
@@ -94,7 +92,7 @@ async def test_handle_low_balance_no_notification_when_not_crossing(
|
||||
mock_db_client = MagicMock()
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
billing.handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
@@ -112,7 +110,6 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
):
|
||||
"""Test that no notification is sent when already below threshold."""
|
||||
|
||||
execution_processor = ExecutionProcessor()
|
||||
user_id = "test-user-123"
|
||||
current_balance = 300 # $3 - below $5 threshold
|
||||
transaction_cost = (
|
||||
@@ -121,11 +118,11 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
|
||||
# Mock dependencies
|
||||
with patch(
|
||||
"backend.executor.manager.queue_notification"
|
||||
"backend.executor.billing.queue_notification"
|
||||
) as mock_queue_notif, patch(
|
||||
"backend.executor.manager.get_notification_manager_client"
|
||||
"backend.executor.billing.get_notification_manager_client"
|
||||
) as mock_get_client, patch(
|
||||
"backend.executor.manager.settings"
|
||||
"backend.executor.billing.settings"
|
||||
) as mock_settings:
|
||||
|
||||
# Setup mocks
|
||||
@@ -137,7 +134,7 @@ async def test_handle_low_balance_no_duplicate_when_already_below(
|
||||
mock_db_client = MagicMock()
|
||||
|
||||
# Test the low balance handler
|
||||
execution_processor._handle_low_balance(
|
||||
billing.handle_low_balance(
|
||||
db_client=mock_db_client,
|
||||
user_id=user_id,
|
||||
current_balance=current_balance,
|
||||
|
||||
@@ -18,9 +18,13 @@ images: {
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import prisma.enums as prisma_enums
|
||||
import prisma.models as prisma_models
|
||||
from faker import Faker
|
||||
|
||||
# Import API functions from the backend
|
||||
@@ -30,10 +34,12 @@ from backend.api.features.store.db import (
|
||||
create_store_submission,
|
||||
review_store_submission,
|
||||
)
|
||||
from backend.api.features.store.model import StoreSubmission
|
||||
from backend.blocks.io import AgentInputBlock
|
||||
from backend.data.auth.api_key import create_api_key
|
||||
from backend.data.credit import get_user_credit_model
|
||||
from backend.data.db import prisma
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
from backend.data.graph import Graph, Link, Node, create_graph, make_graph_model
|
||||
from backend.data.user import get_or_create_user
|
||||
from backend.util.clients import get_supabase
|
||||
|
||||
@@ -60,6 +66,31 @@ MAX_REVIEWS_PER_VERSION = 5
|
||||
GUARANTEED_FEATURED_AGENTS = 8
|
||||
GUARANTEED_FEATURED_CREATORS = 5
|
||||
GUARANTEED_TOP_AGENTS = 10
|
||||
E2E_MARKETPLACE_CREATOR_EMAIL = "test123@example.com"
|
||||
E2E_MARKETPLACE_CREATOR_USERNAME = "e2e-marketplace"
|
||||
E2E_MARKETPLACE_AGENT_SLUG = "e2e-calculator-agent"
|
||||
E2E_MARKETPLACE_AGENT_NAME = "E2E Calculator Agent"
|
||||
E2E_MARKETPLACE_AGENT_INPUT_VALUE = 8
|
||||
E2E_MARKETPLACE_AGENT_OUTPUT_VALUE = 42
|
||||
_LOCAL_TEMPLATE_PATH = (
|
||||
Path(__file__).resolve().parents[1] / "agents" / "calculator-agent.json"
|
||||
)
|
||||
_DOCKER_TEMPLATE_PATH = Path(
|
||||
"/app/autogpt_platform/backend/agents/calculator-agent.json"
|
||||
)
|
||||
E2E_MARKETPLACE_AGENT_TEMPLATE_PATH = (
|
||||
_LOCAL_TEMPLATE_PATH if _LOCAL_TEMPLATE_PATH.exists() else _DOCKER_TEMPLATE_PATH
|
||||
)
|
||||
SEEDED_TEST_EMAILS = [
|
||||
"test123@example.com",
|
||||
"e2e.qa.auth@example.com",
|
||||
"e2e.qa.builder@example.com",
|
||||
"e2e.qa.library@example.com",
|
||||
"e2e.qa.marketplace@example.com",
|
||||
"e2e.qa.settings@example.com",
|
||||
"e2e.qa.parallel.a@example.com",
|
||||
"e2e.qa.parallel.b@example.com",
|
||||
]
|
||||
|
||||
|
||||
def get_image():
|
||||
@@ -100,6 +131,25 @@ def get_category():
|
||||
return random.choice(categories)
|
||||
|
||||
|
||||
def load_deterministic_marketplace_graph() -> Graph:
|
||||
graph = Graph.model_validate(
|
||||
json.loads(E2E_MARKETPLACE_AGENT_TEMPLATE_PATH.read_text())
|
||||
)
|
||||
graph.name = E2E_MARKETPLACE_AGENT_NAME
|
||||
graph.description = (
|
||||
"Deterministic marketplace calculator graph for Playwright PR E2E coverage."
|
||||
)
|
||||
|
||||
for node in graph.nodes:
|
||||
if (
|
||||
node.block_id == AgentInputBlock().id
|
||||
and node.input_default.get("value") is None
|
||||
):
|
||||
node.input_default["value"] = E2E_MARKETPLACE_AGENT_INPUT_VALUE
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
class TestDataCreator:
|
||||
"""Creates test data using API functions for E2E tests."""
|
||||
|
||||
@@ -123,9 +173,9 @@ class TestDataCreator:
|
||||
for i in range(NUM_USERS):
|
||||
try:
|
||||
# Generate test user data
|
||||
if i == 0:
|
||||
# First user should have test123@gmail.com email for testing
|
||||
email = "test123@gmail.com"
|
||||
if i < len(SEEDED_TEST_EMAILS):
|
||||
# Keep a deterministic pool for Playwright global setup and PR smoke flows
|
||||
email = SEEDED_TEST_EMAILS[i]
|
||||
else:
|
||||
email = faker.unique.email()
|
||||
password = "testpassword123" # Standard test password # pragma: allowlist secret # noqa
|
||||
@@ -547,6 +597,46 @@ class TestDataCreator:
|
||||
print(f"Error updating profile {profile.id}: {e}")
|
||||
continue
|
||||
|
||||
deterministic_creator = next(
|
||||
(
|
||||
user
|
||||
for user in self.users
|
||||
if user["email"] == E2E_MARKETPLACE_CREATOR_EMAIL
|
||||
),
|
||||
None,
|
||||
)
|
||||
if deterministic_creator:
|
||||
deterministic_profile = next(
|
||||
(
|
||||
profile
|
||||
for profile in existing_profiles
|
||||
if profile.userId == deterministic_creator["id"]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if deterministic_profile:
|
||||
try:
|
||||
updated_profile = await prisma.profile.update(
|
||||
where={"id": deterministic_profile.id},
|
||||
data={
|
||||
"name": "E2E Marketplace Creator",
|
||||
"username": E2E_MARKETPLACE_CREATOR_USERNAME,
|
||||
"description": "Deterministic marketplace creator for Playwright PR E2E coverage.",
|
||||
"links": ["https://example.com/e2e-marketplace"],
|
||||
"avatarUrl": get_image(),
|
||||
"isFeatured": True,
|
||||
},
|
||||
)
|
||||
profiles = [
|
||||
profile
|
||||
for profile in profiles
|
||||
if profile.get("id") != deterministic_profile.id
|
||||
]
|
||||
if updated_profile is not None:
|
||||
profiles.append(updated_profile.model_dump())
|
||||
except Exception as e:
|
||||
print(f"Error updating deterministic E2E creator profile: {e}")
|
||||
|
||||
self.profiles = profiles
|
||||
return profiles
|
||||
|
||||
@@ -562,58 +652,184 @@ class TestDataCreator:
|
||||
featured_count = 0
|
||||
submission_counter = 0
|
||||
|
||||
# Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
|
||||
# Create a deterministic calculator marketplace agent for PR E2E coverage
|
||||
test_user = next(
|
||||
(user for user in self.users if user["email"] == "test123@gmail.com"), None
|
||||
(
|
||||
user
|
||||
for user in self.users
|
||||
if user["email"] == E2E_MARKETPLACE_CREATOR_EMAIL
|
||||
),
|
||||
None,
|
||||
)
|
||||
if test_user and self.agent_graphs:
|
||||
test_submission_data = {
|
||||
"user_id": test_user["id"],
|
||||
"graph_id": self.agent_graphs[0]["id"],
|
||||
"graph_version": 1,
|
||||
"slug": "test-agent-submission",
|
||||
"name": "Test Agent Submission",
|
||||
"sub_heading": "A test agent for frontend testing",
|
||||
"video_url": "https://www.youtube.com/watch?v=test123",
|
||||
"image_urls": [
|
||||
"https://picsum.photos/200/300",
|
||||
"https://picsum.photos/200/301",
|
||||
"https://picsum.photos/200/302",
|
||||
],
|
||||
"description": "This is a test agent submission specifically created for frontend testing purposes.",
|
||||
"categories": ["test", "demo", "frontend"],
|
||||
"changes_summary": "Initial test submission",
|
||||
}
|
||||
if test_user:
|
||||
deterministic_graph = None
|
||||
|
||||
try:
|
||||
test_submission = await create_store_submission(**test_submission_data)
|
||||
submissions.append(test_submission.model_dump())
|
||||
print("✅ Created special test store submission for test123@gmail.com")
|
||||
|
||||
# ALWAYS approve and feature the test submission
|
||||
if test_submission.listing_version_id:
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=test_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Test submission approved",
|
||||
internal_comments="Auto-approved test submission",
|
||||
reviewer_id=test_user["id"],
|
||||
existing_graph = await prisma_models.AgentGraph.prisma().find_first(
|
||||
where={
|
||||
"userId": test_user["id"],
|
||||
"name": E2E_MARKETPLACE_AGENT_NAME,
|
||||
"isActive": True,
|
||||
},
|
||||
order={"version": "desc"},
|
||||
)
|
||||
if existing_graph:
|
||||
deterministic_graph = {
|
||||
"id": existing_graph.id,
|
||||
"version": existing_graph.version,
|
||||
"name": existing_graph.name,
|
||||
"userId": test_user["id"],
|
||||
}
|
||||
self.agent_graphs.append(deterministic_graph)
|
||||
print(
|
||||
"✅ Reused existing deterministic marketplace graph: "
|
||||
f"{existing_graph.id}"
|
||||
)
|
||||
approved_submissions.append(approved_submission.model_dump())
|
||||
print("✅ Approved test store submission")
|
||||
|
||||
await prisma.storelistingversion.update(
|
||||
where={"id": test_submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
else:
|
||||
deterministic_graph_model = make_graph_model(
|
||||
load_deterministic_marketplace_graph(),
|
||||
test_user["id"],
|
||||
)
|
||||
featured_count += 1
|
||||
print("🌟 Marked test agent as FEATURED")
|
||||
|
||||
deterministic_graph_model.reassign_ids(
|
||||
user_id=test_user["id"],
|
||||
reassign_graph_id=True,
|
||||
)
|
||||
created_deterministic_graph = await create_graph(
|
||||
deterministic_graph_model,
|
||||
test_user["id"],
|
||||
)
|
||||
deterministic_graph = created_deterministic_graph.model_dump()
|
||||
deterministic_graph["userId"] = test_user["id"]
|
||||
self.agent_graphs.append(deterministic_graph)
|
||||
print("✅ Created deterministic marketplace graph")
|
||||
except Exception as e:
|
||||
print(f"Error creating test store submission: {e}")
|
||||
import traceback
|
||||
print(f"Error creating deterministic marketplace graph: {e}")
|
||||
|
||||
traceback.print_exc()
|
||||
if deterministic_graph is None and self.agent_graphs:
|
||||
test_user_graphs = [
|
||||
graph
|
||||
for graph in self.agent_graphs
|
||||
if graph.get("userId") == test_user["id"]
|
||||
]
|
||||
deterministic_graph = next(
|
||||
(
|
||||
graph
|
||||
for graph in test_user_graphs
|
||||
if not graph.get("name", "").startswith("DummyInput ")
|
||||
),
|
||||
test_user_graphs[0] if test_user_graphs else None,
|
||||
)
|
||||
|
||||
if deterministic_graph:
|
||||
test_submission_data = {
|
||||
"user_id": test_user["id"],
|
||||
"graph_id": deterministic_graph["id"],
|
||||
"graph_version": deterministic_graph.get("version", 1),
|
||||
"slug": E2E_MARKETPLACE_AGENT_SLUG,
|
||||
"name": E2E_MARKETPLACE_AGENT_NAME,
|
||||
"sub_heading": "A deterministic calculator agent for PR E2E coverage",
|
||||
"video_url": "https://www.youtube.com/watch?v=test123",
|
||||
"image_urls": [
|
||||
"https://picsum.photos/seed/e2e-marketplace-1/200/300",
|
||||
"https://picsum.photos/seed/e2e-marketplace-2/200/301",
|
||||
"https://picsum.photos/seed/e2e-marketplace-3/200/302",
|
||||
],
|
||||
"description": (
|
||||
"A deterministic marketplace calculator agent that adds "
|
||||
f"{E2E_MARKETPLACE_AGENT_INPUT_VALUE} and 34 to produce "
|
||||
f"{E2E_MARKETPLACE_AGENT_OUTPUT_VALUE} for frontend E2E coverage."
|
||||
),
|
||||
"categories": ["test", "demo", "frontend"],
|
||||
"changes_summary": (
|
||||
"Initial deterministic calculator submission seeded from "
|
||||
"backend/agents/calculator-agent.json"
|
||||
),
|
||||
}
|
||||
|
||||
try:
|
||||
existing_deterministic_submission = (
|
||||
await prisma_models.StoreListingVersion.prisma().find_first(
|
||||
where={
|
||||
"isDeleted": False,
|
||||
"StoreListing": {
|
||||
"is": {
|
||||
"owningUserId": test_user["id"],
|
||||
"slug": E2E_MARKETPLACE_AGENT_SLUG,
|
||||
"isDeleted": False,
|
||||
}
|
||||
},
|
||||
},
|
||||
include={"StoreListing": True},
|
||||
order={"version": "desc"},
|
||||
)
|
||||
)
|
||||
|
||||
if existing_deterministic_submission:
|
||||
test_submission = StoreSubmission.from_listing_version(
|
||||
existing_deterministic_submission
|
||||
)
|
||||
submissions.append(test_submission.model_dump())
|
||||
print(
|
||||
"✅ Reused deterministic marketplace submission: "
|
||||
f"{E2E_MARKETPLACE_AGENT_NAME}"
|
||||
)
|
||||
else:
|
||||
test_submission = await create_store_submission(
|
||||
**test_submission_data
|
||||
)
|
||||
submissions.append(test_submission.model_dump())
|
||||
print(
|
||||
"✅ Created deterministic marketplace submission: "
|
||||
f"{E2E_MARKETPLACE_AGENT_NAME}"
|
||||
)
|
||||
|
||||
current_status = (
|
||||
existing_deterministic_submission.submissionStatus
|
||||
if existing_deterministic_submission
|
||||
else test_submission.status
|
||||
)
|
||||
is_featured = bool(
|
||||
existing_deterministic_submission
|
||||
and existing_deterministic_submission.isFeatured
|
||||
)
|
||||
|
||||
if test_submission.listing_version_id:
|
||||
if current_status != prisma_enums.SubmissionStatus.APPROVED:
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=test_submission.listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Deterministic calculator submission approved",
|
||||
internal_comments="Auto-approved PR E2E marketplace submission",
|
||||
reviewer_id=test_user["id"],
|
||||
)
|
||||
approved_submissions.append(
|
||||
approved_submission.model_dump()
|
||||
)
|
||||
print("✅ Approved deterministic marketplace submission")
|
||||
else:
|
||||
approved_submissions.append(test_submission.model_dump())
|
||||
print(
|
||||
"✅ Deterministic marketplace submission already approved"
|
||||
)
|
||||
|
||||
if is_featured:
|
||||
featured_count += 1
|
||||
print("🌟 Deterministic marketplace agent already FEATURED")
|
||||
else:
|
||||
await prisma.storelistingversion.update(
|
||||
where={"id": test_submission.listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
print(
|
||||
"🌟 Marked deterministic marketplace agent as FEATURED"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating deterministic marketplace submission: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
# Create regular submissions for all users
|
||||
for user in self.users:
|
||||
|
||||
@@ -6,7 +6,8 @@
|
||||
# 5. CLI arguments - docker compose run -e VAR=value
|
||||
|
||||
# Common backend environment - Docker service names
|
||||
x-backend-env: &backend-env # Docker internal service hostnames (override localhost defaults)
|
||||
x-backend-env:
|
||||
&backend-env # Docker internal service hostnames (override localhost defaults)
|
||||
PYRO_HOST: "0.0.0.0"
|
||||
AGENTSERVER_HOST: rest_server
|
||||
SCHEDULER_HOST: scheduler_server
|
||||
@@ -39,7 +40,12 @@ services:
|
||||
context: ../
|
||||
dockerfile: autogpt_platform/backend/Dockerfile
|
||||
target: migrate
|
||||
command: ["sh", "-c", "prisma generate && python3 scripts/gen_prisma_types_stub.py && prisma migrate deploy"]
|
||||
command:
|
||||
[
|
||||
"sh",
|
||||
"-c",
|
||||
"prisma generate && python3 scripts/gen_prisma_types_stub.py && prisma migrate deploy",
|
||||
]
|
||||
develop:
|
||||
watch:
|
||||
- path: ./
|
||||
@@ -79,8 +85,8 @@ services:
|
||||
falkordb:
|
||||
image: falkordb/falkordb:latest
|
||||
ports:
|
||||
- "6380:6379" # FalkorDB Redis protocol (6380 to avoid clash with Redis on 6379)
|
||||
- "3001:3000" # FalkorDB web UI
|
||||
- "6380:6379" # FalkorDB Redis protocol (6380 to avoid clash with Redis on 6379)
|
||||
- "3001:3000" # FalkorDB web UI
|
||||
environment:
|
||||
- REDIS_ARGS=--requirepass ${GRAPHITI_FALKORDB_PASSWORD:-}
|
||||
volumes:
|
||||
@@ -88,7 +94,11 @@ services:
|
||||
networks:
|
||||
- app-network
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "redis-cli -p 6379 -a \"${GRAPHITI_FALKORDB_PASSWORD:-}\" --no-auth-warning ping && wget --spider -q http://localhost:3000 || exit 1"]
|
||||
test:
|
||||
[
|
||||
"CMD-SHELL",
|
||||
'redis-cli -p 6379 -a "${GRAPHITI_FALKORDB_PASSWORD:-}" --no-auth-warning ping && wget --spider -q http://localhost:3000 || exit 1',
|
||||
]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
@@ -300,19 +310,6 @@ services:
|
||||
condition: service_completed_successfully
|
||||
database_manager:
|
||||
condition: service_started
|
||||
# healthcheck:
|
||||
# test:
|
||||
# [
|
||||
# "CMD",
|
||||
# "curl",
|
||||
# "-f",
|
||||
# "-X",
|
||||
# "POST",
|
||||
# "http://localhost:8003/health_check",
|
||||
# ]
|
||||
# interval: 10s
|
||||
# timeout: 10s
|
||||
# retries: 5
|
||||
<<: *backend-env-files
|
||||
environment:
|
||||
<<: *backend-env
|
||||
|
||||
@@ -193,3 +193,4 @@ services:
|
||||
- copilot_executor
|
||||
- websocket_server
|
||||
- database_manager
|
||||
- scheduler_server
|
||||
|
||||
@@ -8,6 +8,7 @@ const config: StorybookConfig = {
|
||||
"../src/components/molecules/**/*.stories.@(js|jsx|mjs|ts|tsx)",
|
||||
"../src/components/ai-elements/**/*.stories.@(js|jsx|mjs|ts|tsx)",
|
||||
"../src/components/renderers/**/*.stories.@(js|jsx|mjs|ts|tsx)",
|
||||
"../src/app/[(]platform[)]/copilot/**/*.stories.@(js|jsx|mjs|ts|tsx)",
|
||||
],
|
||||
addons: [
|
||||
"@storybook/addon-a11y",
|
||||
|
||||
@@ -81,8 +81,10 @@ Every time a new Front-end dependency is added by you or others, you will need t
|
||||
- `pnpm lint` - Run ESLint and Prettier checks
|
||||
- `pnpm format` - Format code with Prettier
|
||||
- `pnpm types` - Run TypeScript type checking
|
||||
- `pnpm test` - Run Playwright tests
|
||||
- `pnpm test-ui` - Run Playwright tests with UI
|
||||
- `pnpm test:unit` - Run the Vitest integration and unit suite with coverage
|
||||
- `pnpm test` - Run the Playwright E2E suite used in CI
|
||||
- `pnpm test-ui` - Run the same Playwright E2E suite with UI
|
||||
- `pnpm test:e2e:no-build` - Run the same Playwright E2E suite against a running app
|
||||
- `pnpm fetch:openapi` - Fetch OpenAPI spec from backend
|
||||
- `pnpm generate:api-client` - Generate API client from OpenAPI spec
|
||||
- `pnpm generate:api` - Fetch OpenAPI spec and generate API client
|
||||
|
||||
@@ -121,35 +121,49 @@ Only when the component has complex internal logic that is hard to exercise thro
|
||||
### Running
|
||||
|
||||
```bash
|
||||
pnpm test # build + run all Playwright tests
|
||||
pnpm test-ui # run with Playwright UI
|
||||
pnpm test:no-build # run against a running dev server
|
||||
pnpm test # build + run the Playwright E2E suite used in CI
|
||||
pnpm test-ui # run the same E2E suite with Playwright UI
|
||||
pnpm test:e2e:no-build # run the same E2E suite against a running dev server
|
||||
pnpm exec playwright test # run the same eight-spec Playwright suite directly
|
||||
```
|
||||
|
||||
### Setup
|
||||
|
||||
1. Start the backend + Supabase stack:
|
||||
- From `autogpt_platform`: `docker compose --profile local up deps_backend -d`
|
||||
2. Seed rich E2E data (creates `test123@gmail.com` with library agents):
|
||||
2. Seed rich E2E data (creates `test123@example.com` with library agents):
|
||||
- From `autogpt_platform/backend`: `poetry run python test/e2e_test_data.py`
|
||||
|
||||
### How Playwright setup works
|
||||
|
||||
- Playwright runs from `frontend/playwright.config.ts` with a global setup step
|
||||
- Global setup creates a user pool via the real signup UI, stored in `frontend/.auth/user-pool.json`
|
||||
- `getTestUser()` (from `src/tests/utils/auth.ts`) pulls a random user from the pool
|
||||
- Playwright runs from `frontend/playwright.config.ts` and keeps browser-only code in `frontend/src/playwright/`
|
||||
- Global setup creates reusable auth states for deterministic seeded accounts in `frontend/.auth/states/`
|
||||
- `getTestUser()` (from `src/playwright/utils/auth.ts`) picks one seeded account for general auth coverage
|
||||
- `getTestUserWithLibraryAgents()` uses the rich user created by the data script
|
||||
|
||||
### Test users
|
||||
|
||||
- **User pool (basic users)** — created automatically by Playwright global setup. Used by `getTestUser()`
|
||||
- **Seeded E2E accounts** — created by backend fixtures and logged in during Playwright global setup. Used by `getTestUser()` and `E2E_AUTH_STATES`
|
||||
- **Rich user with library agents** — created by `backend/test/e2e_test_data.py`. Used by `getTestUserWithLibraryAgents()`
|
||||
|
||||
### Current Playwright E2E suite
|
||||
|
||||
The CI suite is intentionally limited to the cross-page journeys we still require a real browser for. Playwright discovers the PR-gating specs by the `*-happy-path.spec.ts` naming pattern inside `src/playwright/`:
|
||||
|
||||
- `src/playwright/auth-happy-path.spec.ts`
|
||||
- `src/playwright/settings-happy-path.spec.ts`
|
||||
- `src/playwright/api-keys-happy-path.spec.ts`
|
||||
- `src/playwright/builder-happy-path.spec.ts`
|
||||
- `src/playwright/library-happy-path.spec.ts`
|
||||
- `src/playwright/marketplace-happy-path.spec.ts`
|
||||
- `src/playwright/publish-happy-path.spec.ts`
|
||||
- `src/playwright/copilot-happy-path.spec.ts`
|
||||
|
||||
### Resetting the DB
|
||||
|
||||
If you reset the Docker DB and logins start failing:
|
||||
|
||||
1. Delete `frontend/.auth/user-pool.json`
|
||||
1. Delete `frontend/.auth/states/*` and `frontend/.auth/user-pool.json` if it exists
|
||||
2. Re-run `poetry run python test/e2e_test_data.py`
|
||||
|
||||
## Storybook
|
||||
|
||||
@@ -13,11 +13,13 @@
|
||||
"lint": "next lint && prettier --check .",
|
||||
"format": "next lint --fix; prettier --write .",
|
||||
"types": "tsc --noEmit",
|
||||
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test",
|
||||
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && playwright test --ui",
|
||||
"test": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:no-build",
|
||||
"test-ui": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:ui",
|
||||
"test:unit": "vitest run --coverage",
|
||||
"test:unit:watch": "vitest",
|
||||
"test:no-build": "playwright test",
|
||||
"test:e2e": "NEXT_PUBLIC_PW_TEST=true next build --turbo && pnpm test:e2e:no-build",
|
||||
"test:e2e:no-build": "playwright test",
|
||||
"test:e2e:ui": "playwright test --ui",
|
||||
"gentests": "playwright codegen http://localhost:3000",
|
||||
"storybook": "storybook dev -p 6006",
|
||||
"build-storybook": "storybook build",
|
||||
|
||||
@@ -7,10 +7,22 @@ import { defineConfig, devices } from "@playwright/test";
|
||||
import dotenv from "dotenv";
|
||||
import fs from "fs";
|
||||
import path from "path";
|
||||
import { buildCookieConsentStorageState } from "./src/playwright/credentials/storage-state";
|
||||
dotenv.config({ path: path.resolve(__dirname, ".env") });
|
||||
dotenv.config({ path: path.resolve(__dirname, "../backend/.env") });
|
||||
|
||||
const frontendRoot = __dirname.replaceAll("\\", "/");
|
||||
const configuredBaseURL =
|
||||
process.env.PLAYWRIGHT_BASE_URL ?? "http://localhost:3000";
|
||||
const parsedBaseURL = new URL(configuredBaseURL);
|
||||
const baseURL = parsedBaseURL.toString().replace(/\/$/, "");
|
||||
const baseOrigin = parsedBaseURL.origin;
|
||||
const jsonReporterOutputFile = process.env.PLAYWRIGHT_JSON_OUTPUT_FILE;
|
||||
const configuredWorkers = process.env.PLAYWRIGHT_WORKERS
|
||||
? Number(process.env.PLAYWRIGHT_WORKERS)
|
||||
: process.env.CI
|
||||
? 8
|
||||
: undefined;
|
||||
|
||||
// Directory where CI copies .next/static from the Docker container
|
||||
const staticCoverageDir = path.resolve(__dirname, ".next-static-coverage");
|
||||
@@ -57,17 +69,18 @@ function resolveSourceMap(sourcePath: string) {
|
||||
}
|
||||
|
||||
export default defineConfig({
|
||||
testDir: "./src/tests",
|
||||
testDir: "./src/playwright",
|
||||
testMatch: /.*-happy-path\.spec\.ts/,
|
||||
/* Global setup file that runs before all tests */
|
||||
globalSetup: "./src/tests/global-setup.ts",
|
||||
globalSetup: "./src/playwright/global-setup.ts",
|
||||
/* Run tests in files in parallel */
|
||||
fullyParallel: true,
|
||||
/* Fail the build on CI if you accidentally left test.only in the source code. */
|
||||
forbidOnly: !!process.env.CI,
|
||||
/* Retry on CI only */
|
||||
retries: process.env.CI ? 1 : 0,
|
||||
/* use more workers on CI. */
|
||||
workers: process.env.CI ? 4 : undefined,
|
||||
retries: process.env.CI ? Number(process.env.PLAYWRIGHT_RETRIES ?? 2) : 0,
|
||||
/* Higher worker count keeps PR smoke runtime down without sharing page state. */
|
||||
workers: configuredWorkers,
|
||||
/* Reporter to use. See https://playwright.dev/docs/test-reporters */
|
||||
reporter: [
|
||||
["list"],
|
||||
@@ -92,40 +105,25 @@ export default defineConfig({
|
||||
},
|
||||
},
|
||||
],
|
||||
...(jsonReporterOutputFile
|
||||
? [["json", { outputFile: jsonReporterOutputFile }] as const]
|
||||
: []),
|
||||
],
|
||||
/* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
|
||||
use: {
|
||||
/* Base URL to use in actions like `await page.goto('/')`. */
|
||||
baseURL: "http://localhost:3000/",
|
||||
baseURL,
|
||||
|
||||
/* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
|
||||
screenshot: "only-on-failure",
|
||||
bypassCSP: true,
|
||||
|
||||
/* Helps debugging failures */
|
||||
trace: "retain-on-failure",
|
||||
video: "retain-on-failure",
|
||||
trace: process.env.CI ? "on-first-retry" : "retain-on-failure",
|
||||
video: process.env.CI ? "off" : "retain-on-failure",
|
||||
|
||||
/* Auto-accept cookies in all tests to prevent banner interference */
|
||||
storageState: {
|
||||
cookies: [],
|
||||
origins: [
|
||||
{
|
||||
origin: "http://localhost:3000",
|
||||
localStorage: [
|
||||
{
|
||||
name: "autogpt_cookie_consent",
|
||||
value: JSON.stringify({
|
||||
hasConsented: true,
|
||||
timestamp: Date.now(),
|
||||
analytics: true,
|
||||
monitoring: true,
|
||||
}),
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
storageState: buildCookieConsentStorageState(baseOrigin),
|
||||
},
|
||||
/* Maximum time one test can run for */
|
||||
timeout: 25000,
|
||||
@@ -133,7 +131,7 @@ export default defineConfig({
|
||||
/* Configure web server to start automatically (local dev only) */
|
||||
webServer: {
|
||||
command: "pnpm start",
|
||||
url: "http://localhost:3000",
|
||||
url: baseURL,
|
||||
reuseExistingServer: true,
|
||||
},
|
||||
|
||||
|
||||
@@ -29,6 +29,16 @@ const emptyDashboard: PlatformCostDashboard = {
|
||||
total_cost_microdollars: 0,
|
||||
total_requests: 0,
|
||||
total_users: 0,
|
||||
total_input_tokens: 0,
|
||||
total_output_tokens: 0,
|
||||
avg_input_tokens_per_request: 0,
|
||||
avg_output_tokens_per_request: 0,
|
||||
avg_cost_microdollars_per_request: 0,
|
||||
cost_p50_microdollars: 0,
|
||||
cost_p75_microdollars: 0,
|
||||
cost_p95_microdollars: 0,
|
||||
cost_p99_microdollars: 0,
|
||||
cost_buckets: [],
|
||||
by_provider: [],
|
||||
by_user: [],
|
||||
};
|
||||
@@ -47,6 +57,20 @@ const dashboardWithData: PlatformCostDashboard = {
|
||||
total_cost_microdollars: 5_000_000,
|
||||
total_requests: 100,
|
||||
total_users: 5,
|
||||
total_input_tokens: 150000,
|
||||
total_output_tokens: 60000,
|
||||
avg_input_tokens_per_request: 2500,
|
||||
avg_output_tokens_per_request: 1000,
|
||||
avg_cost_microdollars_per_request: 83333,
|
||||
cost_p50_microdollars: 50000,
|
||||
cost_p75_microdollars: 100000,
|
||||
cost_p95_microdollars: 250000,
|
||||
cost_p99_microdollars: 500000,
|
||||
cost_buckets: [
|
||||
{ bucket: "$0-0.50", count: 80 },
|
||||
{ bucket: "$0.50-1", count: 15 },
|
||||
{ bucket: "$1-2", count: 5 },
|
||||
],
|
||||
by_provider: [
|
||||
{
|
||||
provider: "openai",
|
||||
@@ -75,6 +99,7 @@ const dashboardWithData: PlatformCostDashboard = {
|
||||
total_input_tokens: 50000,
|
||||
total_output_tokens: 20000,
|
||||
request_count: 60,
|
||||
cost_bearing_request_count: 40,
|
||||
},
|
||||
],
|
||||
};
|
||||
@@ -134,9 +159,14 @@ describe("PlatformCostContent", () => {
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
// Verify the two summary cards that show $0.0000 — Known Cost and Estimated Total
|
||||
// Known Cost and Estimated Total cards render $0.0000
|
||||
// "Known Cost" appears in both the SummaryCard and the ProviderTable header
|
||||
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
|
||||
expect(screen.getByText("Estimated Total")).toBeDefined();
|
||||
// All cost summary cards (Known Cost, Estimated Total, Avg Cost,
|
||||
// Typical/Upper/High/Peak Cost) show $0.0000
|
||||
const zeroCostItems = screen.getAllByText("$0.0000");
|
||||
expect(zeroCostItems.length).toBe(2);
|
||||
expect(zeroCostItems.length).toBe(7);
|
||||
expect(screen.getByText("No cost data yet")).toBeDefined();
|
||||
});
|
||||
|
||||
@@ -155,7 +185,9 @@ describe("PlatformCostContent", () => {
|
||||
);
|
||||
expect(screen.getByText("$5.0000")).toBeDefined();
|
||||
expect(screen.getByText("100")).toBeDefined();
|
||||
expect(screen.getByText("5")).toBeDefined();
|
||||
// "5" appears in multiple places (Active Users card + bucket count),
|
||||
// so verify at least one element renders it.
|
||||
expect(screen.getAllByText("5").length).toBeGreaterThanOrEqual(1);
|
||||
expect(screen.getByText("openai")).toBeDefined();
|
||||
expect(screen.getByText("google_maps")).toBeDefined();
|
||||
});
|
||||
@@ -223,10 +255,83 @@ describe("PlatformCostContent", () => {
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
// Original 4 cards
|
||||
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
|
||||
expect(screen.getByText("Estimated Total")).toBeDefined();
|
||||
expect(screen.getByText("Total Requests")).toBeDefined();
|
||||
expect(screen.getByText("Active Users")).toBeDefined();
|
||||
// New average/token cards
|
||||
expect(screen.getByText("Avg Cost / Request")).toBeDefined();
|
||||
expect(screen.getByText("Avg Input Tokens")).toBeDefined();
|
||||
expect(screen.getByText("Avg Output Tokens")).toBeDefined();
|
||||
expect(screen.getByText("Total Tokens")).toBeDefined();
|
||||
// Percentile cards (friendlier labels)
|
||||
expect(screen.getByText("Typical Cost (P50)")).toBeDefined();
|
||||
expect(screen.getByText("Upper Cost (P75)")).toBeDefined();
|
||||
expect(screen.getByText("High Cost (P95)")).toBeDefined();
|
||||
expect(screen.getByText("Peak Cost (P99)")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders cost distribution buckets", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("Cost Distribution by Bucket")).toBeDefined();
|
||||
expect(screen.getByText("$0-0.50")).toBeDefined();
|
||||
expect(screen.getByText("$0.50-1")).toBeDefined();
|
||||
expect(screen.getByText("$1-2")).toBeDefined();
|
||||
expect(screen.getByText("80")).toBeDefined();
|
||||
expect(screen.getByText("15")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders new summary card values from fixture data", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
// Avg Input Tokens: 2500 formatted
|
||||
expect(screen.getByText("2,500")).toBeDefined();
|
||||
// Avg Output Tokens: 1000 formatted
|
||||
expect(screen.getByText("1,000")).toBeDefined();
|
||||
// P50 cost: 50000 microdollars = $0.0500
|
||||
expect(screen.getByText("$0.0500")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders user table avg cost column with fixture data", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "by-user" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
// User table should show Avg Cost / Req header
|
||||
expect(screen.getByText("Avg Cost / Req")).toBeDefined();
|
||||
// Input/Output token columns
|
||||
expect(screen.getByText("Input Tokens")).toBeDefined();
|
||||
expect(screen.getByText("Output Tokens")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders filter inputs", async () => {
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
|
||||
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
|
||||
import { Skeleton } from "@/components/atoms/Skeleton/Skeleton";
|
||||
import { formatMicrodollars } from "../helpers";
|
||||
import { formatMicrodollars, formatTokens } from "../helpers";
|
||||
import { SummaryCard } from "./SummaryCard";
|
||||
import { ProviderTable } from "./ProviderTable";
|
||||
import { UserTable } from "./UserTable";
|
||||
import { LogsTable } from "./LogsTable";
|
||||
import { usePlatformCostContent } from "./usePlatformCostContent";
|
||||
import type { CostBucket } from "@/app/api/__generated__/models/costBucket";
|
||||
|
||||
interface Props {
|
||||
searchParams: {
|
||||
@@ -54,6 +55,76 @@ export function PlatformCostContent({ searchParams }: Props) {
|
||||
handleExport,
|
||||
} = usePlatformCostContent(searchParams);
|
||||
|
||||
const summaryCards: { label: string; value: string; subtitle?: string }[] =
|
||||
dashboard
|
||||
? [
|
||||
{
|
||||
label: "Known Cost",
|
||||
value: formatMicrodollars(dashboard.total_cost_microdollars),
|
||||
subtitle: "From providers that report USD cost",
|
||||
},
|
||||
{
|
||||
label: "Estimated Total",
|
||||
value: formatMicrodollars(totalEstimatedCost),
|
||||
subtitle: "Including per-run cost estimates",
|
||||
},
|
||||
{
|
||||
label: "Total Requests",
|
||||
value: dashboard.total_requests.toLocaleString(),
|
||||
},
|
||||
{
|
||||
label: "Active Users",
|
||||
value: dashboard.total_users.toLocaleString(),
|
||||
},
|
||||
{
|
||||
label: "Avg Cost / Request",
|
||||
value: formatMicrodollars(
|
||||
dashboard.avg_cost_microdollars_per_request ?? 0,
|
||||
),
|
||||
subtitle: "Known cost divided by cost-bearing requests",
|
||||
},
|
||||
{
|
||||
label: "Avg Input Tokens",
|
||||
value: Math.round(
|
||||
dashboard.avg_input_tokens_per_request ?? 0,
|
||||
).toLocaleString(),
|
||||
subtitle: "Prompt tokens per request (context size)",
|
||||
},
|
||||
{
|
||||
label: "Avg Output Tokens",
|
||||
value: Math.round(
|
||||
dashboard.avg_output_tokens_per_request ?? 0,
|
||||
).toLocaleString(),
|
||||
subtitle: "Completion tokens per request (response length)",
|
||||
},
|
||||
{
|
||||
label: "Total Tokens",
|
||||
value: `${formatTokens(dashboard.total_input_tokens ?? 0)} in / ${formatTokens(dashboard.total_output_tokens ?? 0)} out`,
|
||||
subtitle: "Prompt vs completion token split",
|
||||
},
|
||||
{
|
||||
label: "Typical Cost (P50)",
|
||||
value: formatMicrodollars(dashboard.cost_p50_microdollars ?? 0),
|
||||
subtitle: "Median cost per request",
|
||||
},
|
||||
{
|
||||
label: "Upper Cost (P75)",
|
||||
value: formatMicrodollars(dashboard.cost_p75_microdollars ?? 0),
|
||||
subtitle: "75th percentile cost",
|
||||
},
|
||||
{
|
||||
label: "High Cost (P95)",
|
||||
value: formatMicrodollars(dashboard.cost_p95_microdollars ?? 0),
|
||||
subtitle: "95th percentile cost",
|
||||
},
|
||||
{
|
||||
label: "Peak Cost (P99)",
|
||||
value: formatMicrodollars(dashboard.cost_p99_microdollars ?? 0),
|
||||
subtitle: "99th percentile cost",
|
||||
},
|
||||
]
|
||||
: [];
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
<div className="flex flex-wrap items-end gap-3 rounded-lg border p-4">
|
||||
@@ -204,37 +275,54 @@ export function PlatformCostContent({ searchParams }: Props) {
|
||||
|
||||
{loading ? (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
|
||||
{[...Array(4)].map((_, i) => (
|
||||
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
|
||||
{/* 12 skeleton placeholders — one per summary card */}
|
||||
{Array.from({ length: 12 }, (_, i) => (
|
||||
<Skeleton key={i} className="h-20 rounded-lg" />
|
||||
))}
|
||||
</div>
|
||||
<Skeleton className="h-32 rounded-lg" />
|
||||
<Skeleton className="h-8 w-48 rounded" />
|
||||
<Skeleton className="h-64 rounded-lg" />
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
{dashboard && (
|
||||
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
|
||||
<SummaryCard
|
||||
label="Known Cost"
|
||||
value={formatMicrodollars(dashboard.total_cost_microdollars)}
|
||||
subtitle="From providers that report USD cost"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Estimated Total"
|
||||
value={formatMicrodollars(totalEstimatedCost)}
|
||||
subtitle="Including per-run cost estimates"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Total Requests"
|
||||
value={dashboard.total_requests.toLocaleString()}
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Active Users"
|
||||
value={dashboard.total_users.toLocaleString()}
|
||||
/>
|
||||
</div>
|
||||
<>
|
||||
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
|
||||
{summaryCards.map((card) => (
|
||||
<SummaryCard
|
||||
key={card.label}
|
||||
label={card.label}
|
||||
value={card.value}
|
||||
subtitle={card.subtitle}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{dashboard.cost_buckets && dashboard.cost_buckets.length > 0 && (
|
||||
<div className="rounded-lg border p-4">
|
||||
<h3 className="mb-3 text-sm font-medium">
|
||||
Cost Distribution by Bucket
|
||||
</h3>
|
||||
<div className="grid grid-cols-2 gap-2 sm:grid-cols-3 md:grid-cols-6">
|
||||
{dashboard.cost_buckets.map((b: CostBucket) => (
|
||||
<div
|
||||
key={b.bucket}
|
||||
className="flex flex-col items-center rounded border p-2 text-center"
|
||||
>
|
||||
<span className="text-xs text-muted-foreground">
|
||||
{b.bucket}
|
||||
</span>
|
||||
<span className="text-lg font-semibold">
|
||||
{b.count.toLocaleString()}
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
<div
|
||||
|
||||
@@ -3,6 +3,7 @@ import {
|
||||
defaultRateFor,
|
||||
estimateCostForRow,
|
||||
formatMicrodollars,
|
||||
formatTokens,
|
||||
rateKey,
|
||||
rateUnitLabel,
|
||||
trackingValue,
|
||||
@@ -33,6 +34,20 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Usage
|
||||
</th>
|
||||
<th
|
||||
scope="col"
|
||||
className="px-4 py-3 text-right"
|
||||
title="Only populated for token-tracking providers (e.g. LLM calls). Non-token rows (per_run, characters, etc.) show —."
|
||||
>
|
||||
Input Tokens
|
||||
</th>
|
||||
<th
|
||||
scope="col"
|
||||
className="px-4 py-3 text-right"
|
||||
title="Only populated for token-tracking providers (e.g. LLM calls). Non-token rows (per_run, characters, etc.) show —."
|
||||
>
|
||||
Output Tokens
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Requests
|
||||
</th>
|
||||
@@ -74,6 +89,16 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
|
||||
<TrackingBadge trackingType={row.tracking_type} />
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">{trackingValue(row)}</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{row.total_input_tokens > 0
|
||||
? formatTokens(row.total_input_tokens)
|
||||
: "-"}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{row.total_output_tokens > 0
|
||||
? formatTokens(row.total_output_tokens)
|
||||
: "-"}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{row.request_count.toLocaleString()}
|
||||
</td>
|
||||
@@ -124,7 +149,7 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
|
||||
{data.length === 0 && (
|
||||
<tr>
|
||||
<td
|
||||
colSpan={8}
|
||||
colSpan={10}
|
||||
className="px-4 py-8 text-center text-muted-foreground"
|
||||
>
|
||||
No cost data yet
|
||||
|
||||
@@ -27,10 +27,7 @@ function UserTable({ data }: Props) {
|
||||
Output Tokens
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Cache Read
|
||||
</th>
|
||||
<th scope="col" className="px-4 py-3 text-right">
|
||||
Cache Write
|
||||
Avg Cost / Req
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
@@ -61,13 +58,12 @@ function UserTable({ data }: Props) {
|
||||
{formatTokens(row.total_output_tokens)}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{(row.total_cache_read_tokens ?? 0) > 0
|
||||
? formatTokens(row.total_cache_read_tokens ?? 0)
|
||||
: "-"}
|
||||
</td>
|
||||
<td className="px-4 py-3 text-right">
|
||||
{(row.total_cache_creation_tokens ?? 0) > 0
|
||||
? formatTokens(row.total_cache_creation_tokens ?? 0)
|
||||
{(row.cost_bearing_request_count ?? 0) > 0 &&
|
||||
row.total_cost_microdollars > 0
|
||||
? formatMicrodollars(
|
||||
row.total_cost_microdollars /
|
||||
(row.cost_bearing_request_count ?? 1),
|
||||
)
|
||||
: "-"}
|
||||
</td>
|
||||
</tr>
|
||||
@@ -75,7 +71,7 @@ function UserTable({ data }: Props) {
|
||||
{data.length === 0 && (
|
||||
<tr>
|
||||
<td
|
||||
colSpan={7}
|
||||
colSpan={6}
|
||||
className="px-4 py-8 text-center text-muted-foreground"
|
||||
>
|
||||
No cost data yet
|
||||
|
||||
@@ -0,0 +1,145 @@
|
||||
import type { Meta, StoryObj } from "@storybook/nextjs";
|
||||
import { ArtifactCard } from "./ArtifactCard";
|
||||
import type { ArtifactRef } from "../../store";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
|
||||
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
|
||||
return {
|
||||
id: "file-001",
|
||||
title: "report.html",
|
||||
mimeType: "text/html",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
|
||||
origin: "agent",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
const meta: Meta<typeof ArtifactCard> = {
|
||||
title: "Copilot/ArtifactCard",
|
||||
component: ArtifactCard,
|
||||
tags: ["autodocs"],
|
||||
parameters: {
|
||||
layout: "padded",
|
||||
docs: {
|
||||
description: {
|
||||
component:
|
||||
"Inline artifact card rendered in chat messages. Openable artifacts show a caret and open the ArtifactPanel on click. Download-only artifacts trigger a file download.",
|
||||
},
|
||||
},
|
||||
},
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<div className="w-96">
|
||||
<Story />
|
||||
</div>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof meta>;
|
||||
|
||||
export const OpenableHTML: Story = {
|
||||
name: "Openable (HTML)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
title: "dashboard.html",
|
||||
mimeType: "text/html",
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
export const OpenableImage: Story = {
|
||||
name: "Openable (Image)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "img-card",
|
||||
title: "chart.png",
|
||||
mimeType: "image/png",
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
export const OpenableCode: Story = {
|
||||
name: "Openable (Code)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
title: "script.py",
|
||||
mimeType: "text/x-python",
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
export const DownloadOnly: Story = {
|
||||
name: "Download Only (ZIP)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
title: "archive.zip",
|
||||
mimeType: "application/zip",
|
||||
sizeBytes: 2_500_000,
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
export const PreviewableVideo: Story = {
|
||||
name: "Previewable (Video)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
title: "demo.mp4",
|
||||
mimeType: "video/mp4",
|
||||
sizeBytes: 15_000_000,
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"Videos with supported formats (MP4, WebM, M4V) are previewable inline in the artifact panel.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const WithSize: Story = {
|
||||
name: "With File Size",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
title: "data.csv",
|
||||
mimeType: "text/csv",
|
||||
sizeBytes: 524_288,
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
export const UserUpload: Story = {
|
||||
name: "User Upload Origin",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
title: "requirements.txt",
|
||||
mimeType: "text/plain",
|
||||
origin: "user-upload",
|
||||
}),
|
||||
},
|
||||
};
|
||||
|
||||
export const ActiveState: Story = {
|
||||
name: "Active (Panel Open)",
|
||||
args: {
|
||||
artifact: makeArtifact({ id: "active-card" }),
|
||||
},
|
||||
decorators: [
|
||||
(Story) => {
|
||||
useCopilotUIStore.setState({
|
||||
artifactPanel: {
|
||||
isOpen: true,
|
||||
isMinimized: false,
|
||||
isMaximized: false,
|
||||
width: 600,
|
||||
activeArtifact: makeArtifact({ id: "active-card" }),
|
||||
history: [],
|
||||
},
|
||||
});
|
||||
return <Story />;
|
||||
},
|
||||
],
|
||||
};
|
||||
@@ -0,0 +1,223 @@
|
||||
import type { Meta, StoryObj } from "@storybook/nextjs";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { ArtifactPanel } from "./ArtifactPanel";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import type { ArtifactRef } from "../../store";
|
||||
|
||||
const PROXY_BASE = "/api/proxy/api/workspace/files";
|
||||
|
||||
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
|
||||
return {
|
||||
id: "file-001",
|
||||
title: "report.html",
|
||||
mimeType: "text/html",
|
||||
sourceUrl: `${PROXY_BASE}/file-001/download`,
|
||||
origin: "agent",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function openPanelWith(artifact: ArtifactRef) {
|
||||
useCopilotUIStore.setState({
|
||||
artifactPanel: {
|
||||
isOpen: true,
|
||||
isMinimized: false,
|
||||
isMaximized: false,
|
||||
width: 600,
|
||||
activeArtifact: artifact,
|
||||
history: [],
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
const meta: Meta<typeof ArtifactPanel> = {
|
||||
title: "Copilot/ArtifactPanel",
|
||||
component: ArtifactPanel,
|
||||
tags: ["autodocs"],
|
||||
parameters: {
|
||||
layout: "fullscreen",
|
||||
docs: {
|
||||
description: {
|
||||
component:
|
||||
"Side panel for previewing workspace artifacts. Supports resize, minimize, maximize, and navigation history. Bug: panel auto-opens on chat switch instead of staying collapsed.",
|
||||
},
|
||||
},
|
||||
},
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<div className="flex h-[600px] w-full">
|
||||
<div className="flex-1 bg-zinc-50 p-8">
|
||||
<p className="text-sm text-zinc-500">Chat area</p>
|
||||
</div>
|
||||
<Story />
|
||||
</div>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof meta>;
|
||||
|
||||
export const OpenWithTextArtifact: Story = {
|
||||
name: "Open — Text File",
|
||||
decorators: [
|
||||
(Story) => {
|
||||
openPanelWith(
|
||||
makeArtifact({ title: "notes.txt", mimeType: "text/plain" }),
|
||||
);
|
||||
return <Story />;
|
||||
},
|
||||
],
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/file-001/download`, () => {
|
||||
return HttpResponse.text(
|
||||
"These are some notes from the agent execution.\n\nKey findings:\n1. Performance improved by 23%\n2. Memory usage reduced\n3. Error rate dropped to 0.1%",
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const OpenWithHTMLArtifact: Story = {
|
||||
name: "Open — HTML",
|
||||
decorators: [
|
||||
(Story) => {
|
||||
openPanelWith(
|
||||
makeArtifact({
|
||||
id: "html-panel",
|
||||
title: "dashboard.html",
|
||||
mimeType: "text/html",
|
||||
sourceUrl: `${PROXY_BASE}/html-panel/download`,
|
||||
}),
|
||||
);
|
||||
return <Story />;
|
||||
},
|
||||
],
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/html-panel/download`, () => {
|
||||
return HttpResponse.text(
|
||||
`<!DOCTYPE html><html><body class="p-8 font-sans"><h1 class="text-2xl font-bold text-indigo-600">Dashboard</h1><p class="mt-2 text-gray-600">HTML artifact in the panel.</p></body></html>`,
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const OpenWithImageArtifact: Story = {
|
||||
name: "Open — Image (Bug: No Loading State)",
|
||||
decorators: [
|
||||
(Story) => {
|
||||
openPanelWith(
|
||||
makeArtifact({
|
||||
id: "img-panel",
|
||||
title: "chart.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: `${PROXY_BASE}/img-panel/download`,
|
||||
}),
|
||||
);
|
||||
return <Story />;
|
||||
},
|
||||
],
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/img-panel/download`, () => {
|
||||
return HttpResponse.text(
|
||||
'<svg xmlns="http://www.w3.org/2000/svg" width="500" height="300"><rect width="500" height="300" fill="#dbeafe"/><text x="250" y="150" text-anchor="middle" fill="#1e40af" font-size="20">Image Preview (no skeleton)</text></svg>',
|
||||
{ headers: { "Content-Type": "image/svg+xml" } },
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"**BUG:** Image artifacts render with a bare `<img>` tag — no loading skeleton or error handling. Compare with text/HTML artifacts which show a proper skeleton while loading.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const MinimizedStrip: Story = {
|
||||
name: "Minimized",
|
||||
decorators: [
|
||||
(Story) => {
|
||||
useCopilotUIStore.setState({
|
||||
artifactPanel: {
|
||||
isOpen: true,
|
||||
isMinimized: true,
|
||||
isMaximized: false,
|
||||
width: 600,
|
||||
activeArtifact: makeArtifact(),
|
||||
history: [],
|
||||
},
|
||||
});
|
||||
return <Story />;
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
export const ErrorState: Story = {
|
||||
name: "Error — Failed to Load (Stale Artifact)",
|
||||
decorators: [
|
||||
(Story) => {
|
||||
openPanelWith(
|
||||
makeArtifact({
|
||||
id: "stale-panel",
|
||||
title: "old-report.html",
|
||||
mimeType: "text/html",
|
||||
sourceUrl: `${PROXY_BASE}/stale-panel/download`,
|
||||
}),
|
||||
);
|
||||
return <Story />;
|
||||
},
|
||||
],
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/stale-panel/download`, () => {
|
||||
return new HttpResponse(null, { status: 404 });
|
||||
}),
|
||||
],
|
||||
},
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"Shows what users see when opening a previously generated artifact that no longer exists on the backend (404). The 'Try again' button retries the fetch.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const Closed: Story = {
|
||||
name: "Closed (Default State)",
|
||||
decorators: [
|
||||
(Story) => {
|
||||
useCopilotUIStore.setState({
|
||||
artifactPanel: {
|
||||
isOpen: false,
|
||||
isMinimized: false,
|
||||
isMaximized: false,
|
||||
width: 600,
|
||||
activeArtifact: null,
|
||||
history: [],
|
||||
},
|
||||
});
|
||||
return <Story />;
|
||||
},
|
||||
],
|
||||
parameters: {
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"The default state — panel is closed. It should only open when a user clicks on an artifact card in the chat.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -0,0 +1,413 @@
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import { downloadArtifact } from "../downloadArtifact";
|
||||
import type { ArtifactRef } from "../../../store";
|
||||
|
||||
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
|
||||
return {
|
||||
id: "file-001",
|
||||
title: "report.pdf",
|
||||
mimeType: "application/pdf",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
|
||||
origin: "agent",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("downloadArtifact", () => {
|
||||
let clickSpy: ReturnType<typeof vi.fn>;
|
||||
let removeSpy: ReturnType<typeof vi.fn>;
|
||||
|
||||
beforeEach(() => {
|
||||
clickSpy = vi.fn();
|
||||
removeSpy = vi.fn();
|
||||
|
||||
vi.stubGlobal(
|
||||
"URL",
|
||||
Object.assign(URL, {
|
||||
createObjectURL: vi.fn().mockReturnValue("blob:fake-url"),
|
||||
revokeObjectURL: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
vi.spyOn(document, "createElement").mockReturnValue({
|
||||
href: "",
|
||||
download: "",
|
||||
click: clickSpy,
|
||||
remove: removeSpy,
|
||||
} as unknown as HTMLAnchorElement);
|
||||
|
||||
vi.spyOn(document.body, "appendChild").mockImplementation(
|
||||
(node) => node as ChildNode,
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("downloads file successfully on 200 response", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["pdf content"])),
|
||||
}),
|
||||
);
|
||||
|
||||
await downloadArtifact(makeArtifact());
|
||||
|
||||
expect(fetch).toHaveBeenCalledWith(
|
||||
"/api/proxy/api/workspace/files/file-001/download",
|
||||
);
|
||||
expect(clickSpy).toHaveBeenCalled();
|
||||
expect(removeSpy).toHaveBeenCalled();
|
||||
expect(URL.revokeObjectURL).toHaveBeenCalledWith("blob:fake-url");
|
||||
});
|
||||
|
||||
it("rejects on persistent server error after exhausting retries", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
|
||||
"Download failed: 500",
|
||||
);
|
||||
expect(clickSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("rejects on persistent network error after exhausting retries", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
return Promise.reject(new Error("Network error"));
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
|
||||
"Network error",
|
||||
);
|
||||
expect(callCount).toBe(3);
|
||||
expect(clickSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("retries on transient network error and succeeds", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.reject(new Error("Connection reset"));
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
await downloadArtifact(makeArtifact());
|
||||
expect(callCount).toBe(2);
|
||||
expect(clickSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("retries on transient 500 and succeeds", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({ ok: false, status: 500 });
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
// Should succeed on second attempt
|
||||
await downloadArtifact(makeArtifact());
|
||||
expect(callCount).toBe(2);
|
||||
expect(clickSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("sanitizes dangerous filenames", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
}),
|
||||
);
|
||||
|
||||
const anchor = {
|
||||
href: "",
|
||||
download: "",
|
||||
click: clickSpy,
|
||||
remove: removeSpy,
|
||||
};
|
||||
vi.spyOn(document, "createElement").mockReturnValue(
|
||||
anchor as unknown as HTMLAnchorElement,
|
||||
);
|
||||
|
||||
await downloadArtifact(makeArtifact({ title: "../../../etc/passwd" }));
|
||||
|
||||
expect(anchor.download).not.toContain("..");
|
||||
expect(anchor.download).not.toContain("/");
|
||||
});
|
||||
|
||||
// ── Transient retry codes ─────────────────────────────────────────
|
||||
|
||||
it("retries on 408 (Request Timeout) and succeeds", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({ ok: false, status: 408 });
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
await downloadArtifact(makeArtifact());
|
||||
expect(callCount).toBe(2);
|
||||
expect(clickSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("retries on 429 (Too Many Requests) and succeeds", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({ ok: false, status: 429 });
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
await downloadArtifact(makeArtifact());
|
||||
expect(callCount).toBe(2);
|
||||
expect(clickSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// ── Non-transient errors ──────────────────────────────────────────
|
||||
|
||||
it("rejects immediately on 403 (non-transient) without retry", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
return Promise.resolve({ ok: false, status: 403 });
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
|
||||
"Download failed: 403",
|
||||
);
|
||||
expect(callCount).toBe(1);
|
||||
expect(clickSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("rejects immediately on 404 without retry", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
return Promise.resolve({ ok: false, status: 404 });
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
|
||||
"Download failed: 404",
|
||||
);
|
||||
expect(callCount).toBe(1);
|
||||
});
|
||||
|
||||
// ── Exhausted retries ─────────────────────────────────────────────
|
||||
|
||||
it("rejects after exhausting all retries on persistent 500", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
return Promise.resolve({ ok: false, status: 500 });
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(downloadArtifact(makeArtifact())).rejects.toThrow(
|
||||
"Download failed: 500",
|
||||
);
|
||||
// Initial attempt + 2 retries = 3 total
|
||||
expect(callCount).toBe(3);
|
||||
expect(clickSpy).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// ── Filename edge cases ───────────────────────────────────────────
|
||||
|
||||
it("falls back to 'download' when title is empty", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
}),
|
||||
);
|
||||
|
||||
const anchor = {
|
||||
href: "",
|
||||
download: "",
|
||||
click: clickSpy,
|
||||
remove: removeSpy,
|
||||
};
|
||||
vi.spyOn(document, "createElement").mockReturnValue(
|
||||
anchor as unknown as HTMLAnchorElement,
|
||||
);
|
||||
|
||||
await downloadArtifact(makeArtifact({ title: "" }));
|
||||
expect(anchor.download).toBe("download");
|
||||
});
|
||||
|
||||
it("falls back to 'download' when title is only dots", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
}),
|
||||
);
|
||||
|
||||
const anchor = {
|
||||
href: "",
|
||||
download: "",
|
||||
click: clickSpy,
|
||||
remove: removeSpy,
|
||||
};
|
||||
vi.spyOn(document, "createElement").mockReturnValue(
|
||||
anchor as unknown as HTMLAnchorElement,
|
||||
);
|
||||
|
||||
// Dot-only names should not produce a hidden or empty filename.
|
||||
await downloadArtifact(makeArtifact({ title: "...." }));
|
||||
expect(anchor.download).toBe("download");
|
||||
});
|
||||
|
||||
it("replaces special chars with underscores (not empty)", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
}),
|
||||
);
|
||||
|
||||
const anchor = {
|
||||
href: "",
|
||||
download: "",
|
||||
click: clickSpy,
|
||||
remove: removeSpy,
|
||||
};
|
||||
vi.spyOn(document, "createElement").mockReturnValue(
|
||||
anchor as unknown as HTMLAnchorElement,
|
||||
);
|
||||
|
||||
await downloadArtifact(makeArtifact({ title: '***???"' }));
|
||||
// Special chars become underscores, not removed
|
||||
expect(anchor.download).toBe("_______");
|
||||
});
|
||||
|
||||
it("strips leading dots from filename", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
}),
|
||||
);
|
||||
|
||||
const anchor = {
|
||||
href: "",
|
||||
download: "",
|
||||
click: clickSpy,
|
||||
remove: removeSpy,
|
||||
};
|
||||
vi.spyOn(document, "createElement").mockReturnValue(
|
||||
anchor as unknown as HTMLAnchorElement,
|
||||
);
|
||||
|
||||
await downloadArtifact(makeArtifact({ title: "...hidden.txt" }));
|
||||
expect(anchor.download).not.toMatch(/^\./);
|
||||
expect(anchor.download).toContain("hidden.txt");
|
||||
});
|
||||
|
||||
it("replaces Windows-reserved characters", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
}),
|
||||
);
|
||||
|
||||
const anchor = {
|
||||
href: "",
|
||||
download: "",
|
||||
click: clickSpy,
|
||||
remove: removeSpy,
|
||||
};
|
||||
vi.spyOn(document, "createElement").mockReturnValue(
|
||||
anchor as unknown as HTMLAnchorElement,
|
||||
);
|
||||
|
||||
await downloadArtifact(
|
||||
makeArtifact({ title: "file<name>with:bad*chars?.txt" }),
|
||||
);
|
||||
expect(anchor.download).not.toMatch(/[<>:*?]/);
|
||||
});
|
||||
|
||||
it("replaces control characters in filename", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
}),
|
||||
);
|
||||
|
||||
const anchor = {
|
||||
href: "",
|
||||
download: "",
|
||||
click: clickSpy,
|
||||
remove: removeSpy,
|
||||
};
|
||||
vi.spyOn(document, "createElement").mockReturnValue(
|
||||
anchor as unknown as HTMLAnchorElement,
|
||||
);
|
||||
|
||||
await downloadArtifact(
|
||||
makeArtifact({ title: "file\x00with\x1fcontrol.txt" }),
|
||||
);
|
||||
expect(anchor.download).not.toMatch(/[\x00-\x1f]/);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,460 @@
|
||||
import type { Meta, StoryObj } from "@storybook/nextjs";
|
||||
import { http, HttpResponse } from "msw";
|
||||
import { ArtifactContent } from "./ArtifactContent";
|
||||
import type { ArtifactRef } from "../../../store";
|
||||
import type { ArtifactClassification } from "../helpers";
|
||||
import {
|
||||
Code,
|
||||
File,
|
||||
FileHtml,
|
||||
FileText,
|
||||
Image,
|
||||
Table,
|
||||
} from "@phosphor-icons/react";
|
||||
|
||||
const PROXY_BASE = "/api/proxy/api/workspace/files";
|
||||
|
||||
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
|
||||
return {
|
||||
id: "file-001",
|
||||
title: "test.txt",
|
||||
mimeType: "text/plain",
|
||||
sourceUrl: `${PROXY_BASE}/file-001/download`,
|
||||
origin: "agent",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function makeClassification(
|
||||
overrides?: Partial<ArtifactClassification>,
|
||||
): ArtifactClassification {
|
||||
return {
|
||||
type: "text",
|
||||
icon: FileText,
|
||||
label: "Text",
|
||||
openable: true,
|
||||
hasSourceToggle: false,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
const meta: Meta<typeof ArtifactContent> = {
|
||||
title: "Copilot/ArtifactContent",
|
||||
component: ArtifactContent,
|
||||
tags: ["autodocs"],
|
||||
parameters: {
|
||||
layout: "padded",
|
||||
docs: {
|
||||
description: {
|
||||
component:
|
||||
"Renders artifact content based on file type classification. Supports images, HTML, code, CSV, JSON, markdown, PDF, and plain text. Bug: image artifacts render as bare <img> with no loading/error states.",
|
||||
},
|
||||
},
|
||||
},
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<div
|
||||
className="flex h-[500px] w-[600px] flex-col overflow-hidden border border-zinc-200"
|
||||
style={{ resize: "both" }}
|
||||
>
|
||||
<Story />
|
||||
</div>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof meta>;
|
||||
|
||||
export const ImageArtifactPNG: Story = {
|
||||
name: "Image (PNG) — No Loading Skeleton (Bug #1)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "img-png",
|
||||
title: "chart.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: `${PROXY_BASE}/img-png/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({ type: "image", icon: Image }),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/img-png/download`, () => {
|
||||
return HttpResponse.text(
|
||||
'<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300"><rect width="400" height="300" fill="#e0e7ff"/><text x="200" y="150" text-anchor="middle" fill="#4338ca" font-size="24">PNG Placeholder</text></svg>',
|
||||
{ headers: { "Content-Type": "image/svg+xml" } },
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"**BUG:** This renders a bare `<img>` tag with no loading skeleton or error handling. Compare with WorkspaceFileRenderer which has proper Skeleton + onError states.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const ImageArtifactSVG: Story = {
|
||||
name: "Image (SVG)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "img-svg",
|
||||
title: "diagram.svg",
|
||||
mimeType: "image/svg+xml",
|
||||
sourceUrl: `${PROXY_BASE}/img-svg/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({ type: "image", icon: Image }),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/img-svg/download`, () => {
|
||||
return HttpResponse.text(
|
||||
'<svg xmlns="http://www.w3.org/2000/svg" width="400" height="300"><rect width="400" height="300" fill="#fef3c7"/><circle cx="200" cy="150" r="80" fill="#f59e0b"/><text x="200" y="155" text-anchor="middle" fill="white" font-size="20">SVG OK</text></svg>',
|
||||
{ headers: { "Content-Type": "image/svg+xml" } },
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const HTMLArtifact: Story = {
|
||||
name: "HTML",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "html-001",
|
||||
title: "page.html",
|
||||
mimeType: "text/html",
|
||||
sourceUrl: `${PROXY_BASE}/html-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "html",
|
||||
icon: FileHtml,
|
||||
label: "HTML",
|
||||
hasSourceToggle: true,
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/html-001/download`, () => {
|
||||
return HttpResponse.text(
|
||||
`<!DOCTYPE html>
|
||||
<html>
|
||||
<head><title>Artifact Preview</title></head>
|
||||
<body class="p-8 font-sans">
|
||||
<h1 class="text-2xl font-bold text-indigo-600 mb-4">HTML Artifact</h1>
|
||||
<p class="text-gray-700">This is an HTML artifact rendered in a sandboxed iframe with Tailwind CSS injected.</p>
|
||||
<div class="mt-4 p-4 bg-blue-50 rounded-lg border border-blue-200">
|
||||
<p class="text-blue-800">Interactive content works via allow-scripts sandbox.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>`,
|
||||
{ headers: { "Content-Type": "text/html" } },
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const CodeArtifact: Story = {
|
||||
name: "Code (Python)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "code-001",
|
||||
title: "analysis.py",
|
||||
mimeType: "text/x-python",
|
||||
sourceUrl: `${PROXY_BASE}/code-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "code",
|
||||
icon: Code,
|
||||
label: "Code",
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/code-001/download`, () => {
|
||||
return HttpResponse.text(
|
||||
`import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def analyze_data(filepath: str) -> pd.DataFrame:
|
||||
"""Load and analyze CSV data."""
|
||||
df = pd.read_csv(filepath)
|
||||
summary = df.describe()
|
||||
print(f"Loaded {len(df)} rows")
|
||||
return summary
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = analyze_data("data.csv")
|
||||
print(result)`,
|
||||
{ headers: { "Content-Type": "text/plain" } },
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const CSVArtifact: Story = {
|
||||
name: "CSV (Spreadsheet)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "csv-001",
|
||||
title: "data.csv",
|
||||
mimeType: "text/csv",
|
||||
sourceUrl: `${PROXY_BASE}/csv-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "csv",
|
||||
icon: Table,
|
||||
label: "Spreadsheet",
|
||||
hasSourceToggle: true,
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/csv-001/download`, () => {
|
||||
return HttpResponse.text(
|
||||
`Name,Age,City,Score
|
||||
Alice,28,New York,92
|
||||
Bob,35,San Francisco,87
|
||||
Charlie,22,Chicago,95
|
||||
Diana,31,Boston,88
|
||||
Eve,27,Seattle,91`,
|
||||
{ headers: { "Content-Type": "text/csv" } },
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const JSONArtifact: Story = {
|
||||
name: "JSON (Data)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "json-001",
|
||||
title: "config.json",
|
||||
mimeType: "application/json",
|
||||
sourceUrl: `${PROXY_BASE}/json-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "json",
|
||||
icon: Code,
|
||||
label: "Data",
|
||||
hasSourceToggle: true,
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/json-001/download`, () => {
|
||||
return HttpResponse.text(
|
||||
JSON.stringify(
|
||||
{
|
||||
name: "AutoGPT Agent",
|
||||
version: "2.0",
|
||||
capabilities: ["web_search", "code_execution", "file_io"],
|
||||
settings: { maxTokens: 4096, temperature: 0.7 },
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
{ headers: { "Content-Type": "application/json" } },
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const MarkdownArtifact: Story = {
|
||||
name: "Markdown",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "md-001",
|
||||
title: "README.md",
|
||||
mimeType: "text/markdown",
|
||||
sourceUrl: `${PROXY_BASE}/md-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "markdown",
|
||||
icon: FileText,
|
||||
label: "Document",
|
||||
hasSourceToggle: true,
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/md-001/download`, () => {
|
||||
return HttpResponse.text(
|
||||
`# Project Summary
|
||||
|
||||
## Overview
|
||||
This is a **markdown** artifact rendered through the global renderer registry.
|
||||
|
||||
## Features
|
||||
- Headings and paragraphs
|
||||
- **Bold** and *italic* text
|
||||
- Lists and code blocks
|
||||
|
||||
\`\`\`python
|
||||
print("Hello from markdown!")
|
||||
\`\`\`
|
||||
|
||||
> Blockquotes are also supported.`,
|
||||
{ headers: { "Content-Type": "text/plain" } },
|
||||
);
|
||||
}),
|
||||
],
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const PDFArtifact: Story = {
|
||||
name: "PDF",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "pdf-001",
|
||||
title: "report.pdf",
|
||||
mimeType: "application/pdf",
|
||||
sourceUrl: `${PROXY_BASE}/pdf-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "pdf",
|
||||
icon: FileText,
|
||||
label: "PDF",
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/pdf-001/download`, () => {
|
||||
return HttpResponse.arrayBuffer(new ArrayBuffer(100), {
|
||||
headers: { "Content-Type": "application/pdf" },
|
||||
});
|
||||
}),
|
||||
],
|
||||
},
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"PDF artifacts are rendered in an unsandboxed iframe using a blob URL (Chromium bug #413851 prevents sandboxed PDF rendering).",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const ErrorState: Story = {
|
||||
name: "Error — Failed to Load Content",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "error-001",
|
||||
title: "old-report.html",
|
||||
mimeType: "text/html",
|
||||
sourceUrl: `${PROXY_BASE}/error-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "html",
|
||||
icon: FileHtml,
|
||||
label: "HTML",
|
||||
hasSourceToggle: true,
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/error-001/download`, () => {
|
||||
return new HttpResponse(null, { status: 404 });
|
||||
}),
|
||||
],
|
||||
},
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"Shows the error state when an artifact fails to load (e.g., old/expired file returning 404). Includes a 'Try again' retry button.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const LoadingSkeleton: Story = {
|
||||
name: "Loading State",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "loading-001",
|
||||
title: "loading.html",
|
||||
mimeType: "text/html",
|
||||
sourceUrl: `${PROXY_BASE}/loading-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "html",
|
||||
icon: FileHtml,
|
||||
label: "HTML",
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
msw: {
|
||||
handlers: [
|
||||
http.get(`${PROXY_BASE}/loading-001/download`, async () => {
|
||||
// Delay response to show loading state
|
||||
await new Promise((r) => setTimeout(r, 999999));
|
||||
return HttpResponse.text("never resolves");
|
||||
}),
|
||||
],
|
||||
},
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"Shows the skeleton loading state while content is being fetched.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const DownloadOnly: Story = {
|
||||
name: "Download Only (Binary)",
|
||||
args: {
|
||||
artifact: makeArtifact({
|
||||
id: "bin-001",
|
||||
title: "archive.zip",
|
||||
mimeType: "application/zip",
|
||||
sourceUrl: `${PROXY_BASE}/bin-001/download`,
|
||||
}),
|
||||
isSourceView: false,
|
||||
classification: makeClassification({
|
||||
type: "download-only",
|
||||
icon: File,
|
||||
label: "File",
|
||||
openable: false,
|
||||
}),
|
||||
},
|
||||
parameters: {
|
||||
docs: {
|
||||
description: {
|
||||
story:
|
||||
"Download-only files (binary, video, etc.) are not rendered inline. The ArtifactPanel shows nothing for these — they are handled by ArtifactCard with a download button.",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
import { globalRegistry } from "@/components/contextual/OutputRenderers";
|
||||
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
|
||||
import { Suspense } from "react";
|
||||
import { Suspense, useState } from "react";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import type { ArtifactRef } from "../../../store";
|
||||
import type { ArtifactClassification } from "../helpers";
|
||||
import { ArtifactReactPreview } from "./ArtifactReactPreview";
|
||||
@@ -63,6 +64,90 @@ function ArtifactContentLoader({
|
||||
);
|
||||
}
|
||||
|
||||
function ArtifactImage({ src, alt }: { src: string; alt: string }) {
|
||||
const [loaded, setLoaded] = useState(false);
|
||||
const [error, setError] = useState(false);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div
|
||||
role="alert"
|
||||
className="flex flex-col items-center justify-center gap-3 p-8 text-center"
|
||||
>
|
||||
<p className="text-sm text-zinc-500">Failed to load image</p>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setError(false);
|
||||
setLoaded(false);
|
||||
}}
|
||||
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
|
||||
>
|
||||
Try again
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="relative flex items-center justify-center p-4">
|
||||
{!loaded && (
|
||||
<Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
|
||||
)}
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src={src}
|
||||
alt={alt}
|
||||
className={`max-h-full max-w-full object-contain transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
|
||||
onLoad={() => setLoaded(true)}
|
||||
onError={() => setError(true)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ArtifactVideo({ src }: { src: string }) {
|
||||
const [loaded, setLoaded] = useState(false);
|
||||
const [error, setError] = useState(false);
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div
|
||||
role="alert"
|
||||
className="flex flex-col items-center justify-center gap-3 p-8 text-center"
|
||||
>
|
||||
<p className="text-sm text-zinc-500">Failed to load video</p>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => {
|
||||
setError(false);
|
||||
setLoaded(false);
|
||||
}}
|
||||
className="rounded-md border border-zinc-200 bg-white px-3 py-1.5 text-xs font-medium text-zinc-700 shadow-sm transition-colors hover:bg-zinc-50 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-violet-400"
|
||||
>
|
||||
Try again
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="relative flex items-center justify-center p-4">
|
||||
{!loaded && (
|
||||
<Skeleton className="absolute inset-4 h-[calc(100%-2rem)] w-[calc(100%-2rem)] rounded-md" />
|
||||
)}
|
||||
<video
|
||||
src={src}
|
||||
controls
|
||||
preload="metadata"
|
||||
className={`max-h-full max-w-full rounded-md transition-opacity ${loaded ? "opacity-100" : "opacity-0"}`}
|
||||
onLoadedMetadata={() => setLoaded(true)}
|
||||
onError={() => setError(true)}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ArtifactRenderer({
|
||||
artifact,
|
||||
content,
|
||||
@@ -79,17 +164,19 @@ function ArtifactRenderer({
|
||||
// Image: render directly from URL (no content fetch)
|
||||
if (classification.type === "image") {
|
||||
return (
|
||||
<div className="flex items-center justify-center p-4">
|
||||
{/* eslint-disable-next-line @next/next/no-img-element */}
|
||||
<img
|
||||
src={artifact.sourceUrl}
|
||||
alt={artifact.title}
|
||||
className="max-h-full max-w-full object-contain"
|
||||
/>
|
||||
</div>
|
||||
<ArtifactImage
|
||||
key={artifact.sourceUrl}
|
||||
src={artifact.sourceUrl}
|
||||
alt={artifact.title}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Video: render with <video> controls (no content fetch)
|
||||
if (classification.type === "video") {
|
||||
return <ArtifactVideo key={artifact.sourceUrl} src={artifact.sourceUrl} />;
|
||||
}
|
||||
|
||||
if (classification.type === "pdf" && pdfUrl) {
|
||||
// No sandbox — Chrome/Edge block PDF rendering in sandboxed iframes
|
||||
// (Chromium bug #413851). The blob URL has a null origin so it can't
|
||||
@@ -164,7 +251,16 @@ function ArtifactRenderer({
|
||||
|
||||
// CSV: pass with explicit metadata so CSVRenderer matches
|
||||
if (classification.type === "csv") {
|
||||
const csvMeta = { mimeType: "text/csv", filename: artifact.title };
|
||||
const normalizedMime = artifact.mimeType
|
||||
?.toLowerCase()
|
||||
.split(";")[0]
|
||||
?.trim();
|
||||
const csvMimeType =
|
||||
normalizedMime === "text/tab-separated-values" ||
|
||||
artifact.title.toLowerCase().endsWith(".tsv")
|
||||
? "text/tab-separated-values"
|
||||
: "text/csv";
|
||||
const csvMeta = { mimeType: csvMimeType, filename: artifact.title };
|
||||
const csvRenderer = globalRegistry.getRenderer(content, csvMeta);
|
||||
if (csvRenderer) {
|
||||
return <div className="p-4">{csvRenderer.render(content, csvMeta)}</div>;
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { ArtifactReactPreview } from "./ArtifactReactPreview";
|
||||
import {
|
||||
buildReactArtifactSrcDoc,
|
||||
collectPreviewStyles,
|
||||
transpileReactArtifactSource,
|
||||
} from "./reactArtifactPreview";
|
||||
|
||||
vi.mock("./reactArtifactPreview", () => ({
|
||||
buildReactArtifactSrcDoc: vi.fn(),
|
||||
collectPreviewStyles: vi.fn(),
|
||||
transpileReactArtifactSource: vi.fn(),
|
||||
}));
|
||||
|
||||
describe("ArtifactReactPreview", () => {
|
||||
beforeEach(() => {
|
||||
vi.mocked(collectPreviewStyles).mockReturnValue("<style>preview</style>");
|
||||
vi.mocked(buildReactArtifactSrcDoc).mockReturnValue("<html>preview</html>");
|
||||
});
|
||||
|
||||
it("renders an iframe preview after transpilation succeeds", async () => {
|
||||
vi.mocked(transpileReactArtifactSource).mockResolvedValue(
|
||||
"module.exports.default = function Artifact() { return null; };",
|
||||
);
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactReactPreview
|
||||
source="export default function Artifact() { return null; }"
|
||||
title="Artifact.tsx"
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(buildReactArtifactSrcDoc).toHaveBeenCalledWith(
|
||||
"module.exports.default = function Artifact() { return null; };",
|
||||
"Artifact.tsx",
|
||||
"<style>preview</style>",
|
||||
);
|
||||
});
|
||||
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
|
||||
expect(iframe?.getAttribute("title")).toBe("Artifact.tsx preview");
|
||||
expect(iframe?.getAttribute("srcdoc")).toBe("<html>preview</html>");
|
||||
});
|
||||
|
||||
it("shows a readable error when transpilation fails", async () => {
|
||||
vi.mocked(transpileReactArtifactSource).mockRejectedValue(
|
||||
new Error("Transpile exploded"),
|
||||
);
|
||||
|
||||
render(
|
||||
<ArtifactReactPreview
|
||||
source="export default function Artifact() {"
|
||||
title="Broken.tsx"
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Failed to render React preview")).toBeTruthy();
|
||||
});
|
||||
|
||||
expect(screen.getByText("Transpile exploded")).toBeTruthy();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,970 @@
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import {
|
||||
cleanup,
|
||||
fireEvent,
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
} from "@testing-library/react";
|
||||
import { ArtifactContent } from "../ArtifactContent";
|
||||
import type { ArtifactRef } from "../../../../store";
|
||||
import { classifyArtifact, type ArtifactClassification } from "../../helpers";
|
||||
import { globalRegistry } from "@/components/contextual/OutputRenderers";
|
||||
import { codeRenderer } from "@/components/contextual/OutputRenderers/renderers/CodeRenderer";
|
||||
import { ArtifactReactPreview } from "../ArtifactReactPreview";
|
||||
|
||||
// Mock the renderers so we don't pull in the full renderer dependency tree
|
||||
vi.mock("@/components/contextual/OutputRenderers", () => ({
|
||||
globalRegistry: {
|
||||
getRenderer: vi.fn().mockReturnValue({
|
||||
render: vi.fn((_val: unknown, meta: Record<string, unknown>) => (
|
||||
<div data-testid="global-renderer">
|
||||
rendered:{String(meta?.mimeType ?? "unknown")}
|
||||
</div>
|
||||
)),
|
||||
}),
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock(
|
||||
"@/components/contextual/OutputRenderers/renderers/CodeRenderer",
|
||||
() => ({
|
||||
codeRenderer: {
|
||||
render: vi.fn((content: string) => (
|
||||
<div data-testid="code-renderer">{content}</div>
|
||||
)),
|
||||
},
|
||||
}),
|
||||
);
|
||||
|
||||
vi.mock("../ArtifactReactPreview", () => ({
|
||||
ArtifactReactPreview: vi.fn(
|
||||
({ source, title }: { source: string; title: string }) => (
|
||||
<div data-testid="react-preview" data-title={title}>
|
||||
{source}
|
||||
</div>
|
||||
),
|
||||
),
|
||||
}));
|
||||
|
||||
function makeArtifact(overrides?: Partial<ArtifactRef>): ArtifactRef {
|
||||
return {
|
||||
id: "file-001",
|
||||
title: "test.txt",
|
||||
mimeType: "text/plain",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/file-001/download",
|
||||
origin: "agent",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
function makeClassification(
|
||||
overrides?: Partial<ArtifactClassification>,
|
||||
): ArtifactClassification {
|
||||
return {
|
||||
type: "text",
|
||||
icon: vi.fn(() => null) as unknown as ArtifactClassification["icon"],
|
||||
label: "Text",
|
||||
openable: true,
|
||||
hasSourceToggle: false,
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("ArtifactContent", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("file content here"),
|
||||
blob: () => Promise.resolve(new Blob(["content"])),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
cleanup();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
// ── Image ─────────────────────────────────────────────────────────
|
||||
|
||||
it("renders image artifact as img tag with loading skeleton", () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "img-001",
|
||||
title: "photo.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/img-001/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "image" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const img = container.querySelector("img");
|
||||
expect(img).toBeTruthy();
|
||||
expect(img?.getAttribute("src")).toBe(
|
||||
"/api/proxy/api/workspace/files/img-001/download",
|
||||
);
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("image artifact shows loading skeleton before image loads", () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "img-skeleton",
|
||||
title: "photo.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/img-skeleton/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "image" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
// Skeleton uses animate-pulse class
|
||||
const skeleton = container.querySelector('[class*="animate-pulse"]');
|
||||
expect(skeleton).toBeTruthy();
|
||||
});
|
||||
|
||||
it("image artifact shows error state when image fails to load", () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "img-error",
|
||||
title: "broken.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/img-error/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "image" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const img = container.querySelector("img");
|
||||
expect(img).toBeTruthy();
|
||||
fireEvent.error(img!);
|
||||
|
||||
const errorAlert = screen.queryByRole("alert");
|
||||
expect(errorAlert).toBeTruthy();
|
||||
expect(screen.queryByText("Failed to load image")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("image retry resets error and re-shows img", async () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "img-retry",
|
||||
title: "retry.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/img-retry/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "image" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const img = container.querySelector("img");
|
||||
fireEvent.error(img!);
|
||||
|
||||
// Should show error state
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Failed to load image")).toBeTruthy();
|
||||
});
|
||||
|
||||
// Click "Try again"
|
||||
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
|
||||
|
||||
// Error should be cleared, img should reappear
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Failed to load image")).toBeNull();
|
||||
expect(container.querySelector("img")).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
// ── Video ─────────────────────────────────────────────────────────
|
||||
|
||||
it("renders video artifact with video tag and controls", () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "vid-001",
|
||||
title: "clip.mp4",
|
||||
mimeType: "video/mp4",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/vid-001/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "video" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const video = container.querySelector("video");
|
||||
expect(video).toBeTruthy();
|
||||
expect(video?.hasAttribute("controls")).toBe(true);
|
||||
expect(video?.getAttribute("src")).toBe(
|
||||
"/api/proxy/api/workspace/files/vid-001/download",
|
||||
);
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("video shows loading skeleton before metadata loads", () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "vid-skel",
|
||||
title: "clip.mp4",
|
||||
mimeType: "video/mp4",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/vid-skel/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "video" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const skeleton = container.querySelector('[class*="animate-pulse"]');
|
||||
expect(skeleton).toBeTruthy();
|
||||
|
||||
// After metadata loads, skeleton should disappear
|
||||
const video = container.querySelector("video");
|
||||
fireEvent.loadedMetadata(video!);
|
||||
|
||||
expect(container.querySelector('[class*="animate-pulse"]')).toBeNull();
|
||||
});
|
||||
|
||||
it("video shows error state when video fails to load", () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "vid-error",
|
||||
title: "broken.mp4",
|
||||
mimeType: "video/mp4",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/vid-error/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "video" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const video = container.querySelector("video");
|
||||
expect(video).toBeTruthy();
|
||||
fireEvent.error(video!);
|
||||
|
||||
const errorAlert = screen.queryByRole("alert");
|
||||
expect(errorAlert).toBeTruthy();
|
||||
expect(screen.queryByText("Failed to load video")).toBeTruthy();
|
||||
});
|
||||
|
||||
it("video retry resets error and re-shows video", async () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "vid-retry",
|
||||
title: "retry.mp4",
|
||||
mimeType: "video/mp4",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/vid-retry/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "video" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const video = container.querySelector("video");
|
||||
fireEvent.error(video!);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Failed to load video")).toBeTruthy();
|
||||
});
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: /try again/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Failed to load video")).toBeNull();
|
||||
expect(container.querySelector("video")).toBeTruthy();
|
||||
});
|
||||
});
|
||||
|
||||
// ── PDF ───────────────────────────────────────────────────────────
|
||||
|
||||
it("renders PDF artifact in unsandboxed iframe with blob URL", async () => {
|
||||
const blobUrl = "blob:http://localhost/fake-pdf-blob";
|
||||
vi.stubGlobal(
|
||||
"URL",
|
||||
Object.assign(URL, {
|
||||
createObjectURL: vi.fn().mockReturnValue(blobUrl),
|
||||
revokeObjectURL: vi.fn(),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "pdf-render",
|
||||
title: "report.pdf",
|
||||
mimeType: "application/pdf",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/pdf-render/download",
|
||||
});
|
||||
const classification = makeClassification({ type: "pdf" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("src")).toBe(blobUrl);
|
||||
// No sandbox attribute — Chrome blocks PDF in sandboxed iframes
|
||||
expect(iframe?.hasAttribute("sandbox")).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Fetch error ───────────────────────────────────────────────────
|
||||
|
||||
it("shows error state with retry button on fetch failure", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 404,
|
||||
text: () => Promise.resolve("Not found"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "error-content-test" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const errorText = await screen.findByText("Failed to load content");
|
||||
expect(errorText).toBeTruthy();
|
||||
|
||||
const retryButtons = screen.getAllByRole("button", { name: /try again/i });
|
||||
expect(retryButtons.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
// ── HTML ──────────────────────────────────────────────────────────
|
||||
|
||||
it("renders HTML content in sandboxed iframe", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () =>
|
||||
Promise.resolve("<html><body><h1>Hello World</h1></body></html>"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "html-001",
|
||||
title: "page.html",
|
||||
mimeType: "text/html",
|
||||
});
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTitle("page.html");
|
||||
const iframe = container.querySelector("iframe");
|
||||
expect(iframe).toBeTruthy();
|
||||
expect(iframe?.getAttribute("sandbox")).toBe("allow-scripts");
|
||||
});
|
||||
|
||||
// ── Source view ───────────────────────────────────────────────────
|
||||
|
||||
it("renders source view as pre tag", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("source code here"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "source-view-test" });
|
||||
const classification = makeClassification({
|
||||
type: "html",
|
||||
hasSourceToggle: true,
|
||||
});
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={true}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByText("source code here");
|
||||
const pre = container.querySelector("pre");
|
||||
expect(pre).toBeTruthy();
|
||||
expect(pre?.textContent).toBe("source code here");
|
||||
});
|
||||
|
||||
// ── React ─────────────────────────────────────────────────────────
|
||||
|
||||
it("renders react artifacts via ArtifactReactPreview", async () => {
|
||||
const jsxSource = "export default function App() { return <div>Hi</div>; }";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(jsxSource),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "react-001",
|
||||
title: "App.tsx",
|
||||
mimeType: "text/tsx",
|
||||
});
|
||||
const classification = makeClassification({ type: "react" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const preview = await screen.findByTestId("react-preview");
|
||||
expect(preview).toBeTruthy();
|
||||
expect(preview.textContent).toContain(jsxSource);
|
||||
expect(preview.getAttribute("data-title")).toBe("App.tsx");
|
||||
});
|
||||
|
||||
it("routes a concrete props-based TSX artifact into ArtifactReactPreview", async () => {
|
||||
const jsxSource = `
|
||||
import React, { FC, useState } from "react";
|
||||
|
||||
interface ArtifactFile {
|
||||
id: string;
|
||||
name: string;
|
||||
mimeType: string;
|
||||
url: string;
|
||||
sizeBytes: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
files: ArtifactFile[];
|
||||
onSelect: (file: ArtifactFile) => void;
|
||||
}
|
||||
|
||||
export const previewProps: Props = {
|
||||
files: [
|
||||
{
|
||||
id: "1",
|
||||
name: "report.png",
|
||||
mimeType: "image/png",
|
||||
url: "/report.png",
|
||||
sizeBytes: 2048,
|
||||
},
|
||||
],
|
||||
onSelect: () => {},
|
||||
};
|
||||
|
||||
const ArtifactList: FC<Props> = ({ files, onSelect }) => {
|
||||
const [selected, setSelected] = useState<string | null>(null);
|
||||
|
||||
const handleClick = (file: ArtifactFile) => {
|
||||
setSelected(file.id);
|
||||
onSelect(file);
|
||||
};
|
||||
|
||||
return (
|
||||
<ul>
|
||||
{files.map((file) => (
|
||||
<li key={file.id} onClick={() => handleClick(file)}>
|
||||
<span>{selected === file.id ? "selected" : file.name}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
);
|
||||
};
|
||||
|
||||
export default ArtifactList;
|
||||
`;
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(jsxSource),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "react-props-001",
|
||||
title: "ArtifactList.tsx",
|
||||
mimeType: "text/tsx",
|
||||
});
|
||||
const classification = classifyArtifact(artifact.mimeType, artifact.title);
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const preview = await screen.findByTestId("react-preview");
|
||||
expect(preview.textContent).toContain("previewProps");
|
||||
expect(preview.getAttribute("data-title")).toBe("ArtifactList.tsx");
|
||||
expect(vi.mocked(ArtifactReactPreview).mock.calls[0]?.[0]).toEqual(
|
||||
expect.objectContaining({
|
||||
source: expect.stringContaining("export const previewProps"),
|
||||
title: "ArtifactList.tsx",
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
// ── Code ──────────────────────────────────────────────────────────
|
||||
|
||||
it("renders code artifacts via codeRenderer", async () => {
|
||||
const code = 'def hello():\n print("hi")';
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(code),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "code-render-001",
|
||||
title: "script.py",
|
||||
mimeType: "text/x-python",
|
||||
});
|
||||
const classification = makeClassification({ type: "code" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("code-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain(code);
|
||||
});
|
||||
|
||||
it.each([
|
||||
{
|
||||
filename: "events.jsonl",
|
||||
mimeType: "application/x-ndjson",
|
||||
content: '{"event":"start"}\n{"event":"finish"}',
|
||||
},
|
||||
{
|
||||
filename: ".env.local",
|
||||
mimeType: "text/plain",
|
||||
content: "OPENAI_API_KEY=test\nDEBUG=true",
|
||||
},
|
||||
{
|
||||
filename: "Dockerfile",
|
||||
mimeType: "text/plain",
|
||||
content: "FROM node:20\nRUN pnpm install",
|
||||
},
|
||||
{
|
||||
filename: "schema.graphql",
|
||||
mimeType: "text/plain",
|
||||
content: "type Query { viewer: User }",
|
||||
},
|
||||
])(
|
||||
"renders concrete code artifact $filename through codeRenderer",
|
||||
async ({ filename, mimeType, content }) => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(content),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: `code-${filename}`,
|
||||
title: filename,
|
||||
mimeType,
|
||||
});
|
||||
const classification = classifyArtifact(
|
||||
artifact.mimeType,
|
||||
artifact.title,
|
||||
);
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTestId("code-renderer");
|
||||
|
||||
expect(classification.type).toBe("code");
|
||||
expect(vi.mocked(codeRenderer.render)).toHaveBeenCalledWith(
|
||||
content,
|
||||
expect.objectContaining({
|
||||
filename,
|
||||
mimeType,
|
||||
type: "code",
|
||||
}),
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
// ── JSON ──────────────────────────────────────────────────────────
|
||||
|
||||
it("renders valid JSON via globalRegistry", async () => {
|
||||
const jsonContent = JSON.stringify({ key: "value" }, null, 2);
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(jsonContent),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "json-render-001",
|
||||
title: "data.json",
|
||||
mimeType: "application/json",
|
||||
});
|
||||
const classification = makeClassification({ type: "json" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain("application/json");
|
||||
});
|
||||
|
||||
it("renders invalid JSON as fallback pre tag", async () => {
|
||||
const { globalRegistry } = await import(
|
||||
"@/components/contextual/OutputRenderers"
|
||||
);
|
||||
const originalImpl = vi
|
||||
.mocked(globalRegistry.getRenderer)
|
||||
.getMockImplementation();
|
||||
|
||||
// For invalid JSON, JSON.parse throws, then the registry fallback
|
||||
// also returns null → falls through to <pre>
|
||||
vi.mocked(globalRegistry.getRenderer).mockReturnValue(null);
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("{invalid json!!!"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "json-invalid-001",
|
||||
title: "bad.json",
|
||||
mimeType: "application/json",
|
||||
});
|
||||
const classification = makeClassification({ type: "json" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
const pre = container.querySelector("pre");
|
||||
expect(pre).toBeTruthy();
|
||||
expect(pre?.textContent).toBe("{invalid json!!!");
|
||||
});
|
||||
|
||||
// Restore
|
||||
if (originalImpl) {
|
||||
vi.mocked(globalRegistry.getRenderer).mockImplementation(originalImpl);
|
||||
}
|
||||
});
|
||||
|
||||
// ── CSV ───────────────────────────────────────────────────────────
|
||||
|
||||
it("renders CSV via globalRegistry with text/csv metadata", async () => {
|
||||
const csvContent = "Name,Age\nAlice,30\nBob,25";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(csvContent),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "csv-render-001",
|
||||
title: "data.csv",
|
||||
mimeType: "text/csv",
|
||||
});
|
||||
const classification = makeClassification({
|
||||
type: "csv",
|
||||
hasSourceToggle: true,
|
||||
});
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain("text/csv");
|
||||
});
|
||||
|
||||
it("renders TSV via globalRegistry with tab-separated metadata", async () => {
|
||||
const tsvContent = "Name\tAge\nAlice\t30\nBob\t25";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(tsvContent),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "tsv-render-001",
|
||||
title: "data.tsv",
|
||||
mimeType: "text/tab-separated-values",
|
||||
});
|
||||
const classification = makeClassification({
|
||||
type: "csv",
|
||||
hasSourceToggle: true,
|
||||
});
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain("text/tab-separated-values");
|
||||
});
|
||||
|
||||
// ── Markdown ──────────────────────────────────────────────────────
|
||||
|
||||
it("renders markdown via globalRegistry", async () => {
|
||||
const mdContent = "# Hello\n\nThis is **markdown**.";
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(mdContent),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "md-render-001",
|
||||
title: "README.md",
|
||||
mimeType: "text/markdown",
|
||||
});
|
||||
const classification = makeClassification({
|
||||
type: "markdown",
|
||||
hasSourceToggle: true,
|
||||
});
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
expect(rendered.textContent).toContain("text/markdown");
|
||||
});
|
||||
|
||||
// ── Text fallback ─────────────────────────────────────────────────
|
||||
|
||||
it("renders text artifacts via globalRegistry fallback", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("plain text content"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "text-render-001",
|
||||
title: "notes.txt",
|
||||
mimeType: "text/plain",
|
||||
});
|
||||
const classification = makeClassification({ type: "text" });
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
const rendered = await screen.findByTestId("global-renderer");
|
||||
expect(rendered).toBeTruthy();
|
||||
});
|
||||
|
||||
it.each([
|
||||
{
|
||||
filename: "calendar.ics",
|
||||
mimeType: "text/calendar",
|
||||
content: "BEGIN:VCALENDAR\nVERSION:2.0\nEND:VCALENDAR",
|
||||
},
|
||||
{
|
||||
filename: "contact.vcf",
|
||||
mimeType: "text/vcard",
|
||||
content: "BEGIN:VCARD\nVERSION:4.0\nFN:Alice Example\nEND:VCARD",
|
||||
},
|
||||
])(
|
||||
"renders concrete text artifact $filename through the global renderer path",
|
||||
async ({ filename, mimeType, content }) => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(content),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: `text-${filename}`,
|
||||
title: filename,
|
||||
mimeType,
|
||||
});
|
||||
const classification = classifyArtifact(
|
||||
artifact.mimeType,
|
||||
artifact.title,
|
||||
);
|
||||
|
||||
render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await screen.findByTestId("global-renderer");
|
||||
|
||||
expect(classification.type).toBe("text");
|
||||
expect(vi.mocked(globalRegistry.getRenderer)).toHaveBeenCalledWith(
|
||||
content,
|
||||
expect.objectContaining({
|
||||
filename,
|
||||
mimeType,
|
||||
}),
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
it("falls back to pre tag when no renderer matches", async () => {
|
||||
const { globalRegistry } = await import(
|
||||
"@/components/contextual/OutputRenderers"
|
||||
);
|
||||
const originalImpl = vi
|
||||
.mocked(globalRegistry.getRenderer)
|
||||
.getMockImplementation();
|
||||
|
||||
vi.mocked(globalRegistry.getRenderer).mockReturnValue(null);
|
||||
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("raw content fallback"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({
|
||||
id: "fallback-pre-001",
|
||||
title: "unknown.txt",
|
||||
mimeType: "text/plain",
|
||||
});
|
||||
const classification = makeClassification({ type: "text" });
|
||||
|
||||
const { container } = render(
|
||||
<ArtifactContent
|
||||
artifact={artifact}
|
||||
isSourceView={false}
|
||||
classification={classification}
|
||||
/>,
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
const pre = container.querySelector("pre");
|
||||
expect(pre).toBeTruthy();
|
||||
expect(pre?.textContent).toBe("raw content fallback");
|
||||
});
|
||||
|
||||
// Restore
|
||||
if (originalImpl) {
|
||||
vi.mocked(globalRegistry.getRenderer).mockImplementation(originalImpl);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -3,6 +3,7 @@ import { renderHook, waitFor, act } from "@testing-library/react";
|
||||
import {
|
||||
useArtifactContent,
|
||||
getCachedArtifactContent,
|
||||
clearContentCache,
|
||||
} from "../useArtifactContent";
|
||||
import type { ArtifactRef } from "../../../../store";
|
||||
import type { ArtifactClassification } from "../../helpers";
|
||||
@@ -33,6 +34,7 @@ function makeClassification(
|
||||
|
||||
describe("useArtifactContent", () => {
|
||||
beforeEach(() => {
|
||||
clearContentCache();
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
@@ -44,6 +46,7 @@ describe("useArtifactContent", () => {
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
clearContentCache();
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
@@ -109,9 +112,12 @@ describe("useArtifactContent", () => {
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
});
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("404");
|
||||
expect(result.current.content).toBeNull();
|
||||
@@ -132,6 +138,176 @@ describe("useArtifactContent", () => {
|
||||
expect(getCachedArtifactContent("cache-test")).toBe("file content here");
|
||||
});
|
||||
|
||||
it("sets error on fetch failure for HTML artifacts (stale artifact)", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 404,
|
||||
text: () => Promise.resolve("Not found"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "stale-html-artifact" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("404");
|
||||
expect(result.current.content).toBeNull();
|
||||
});
|
||||
|
||||
it("sets error on network failure", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockRejectedValue(new Error("Network error")),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "network-error-artifact" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("Network error");
|
||||
expect(result.current.content).toBeNull();
|
||||
});
|
||||
|
||||
it("retries transient HTML fetch failures before surfacing an error", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount < 3) {
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 503,
|
||||
headers: {
|
||||
get: () => "application/json",
|
||||
},
|
||||
json: () => Promise.resolve({ detail: "temporary upstream error" }),
|
||||
});
|
||||
}
|
||||
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("<html>ok now</html>"),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "transient-html-retry" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.content).toBe("<html>ok now</html>");
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(callCount).toBe(3);
|
||||
expect(result.current.error).toBeNull();
|
||||
});
|
||||
|
||||
it("surfaces backend error detail from JSON responses", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 404,
|
||||
headers: {
|
||||
get: () => "application/json",
|
||||
},
|
||||
json: () => Promise.resolve({ detail: "File not found" }),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "json-error-detail" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("404");
|
||||
expect(result.current.error).toContain("File not found");
|
||||
});
|
||||
|
||||
it("retry after 404 on HTML artifact clears cache and re-fetches", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
if (callCount === 1) {
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 404,
|
||||
text: () => Promise.resolve("Not found"),
|
||||
});
|
||||
}
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
text: () => Promise.resolve("<html>recovered</html>"),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "retry-html-artifact" });
|
||||
const classification = makeClassification({ type: "html" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
});
|
||||
|
||||
act(() => {
|
||||
result.current.retry();
|
||||
});
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.content).toBe("<html>recovered</html>");
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toBeNull();
|
||||
});
|
||||
|
||||
it("retry clears cache and re-fetches", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
@@ -164,4 +340,162 @@ describe("useArtifactContent", () => {
|
||||
expect(result.current.content).toBe("response 2");
|
||||
});
|
||||
});
|
||||
|
||||
// ── Non-transient errors ──────────────────────────────────────────
|
||||
|
||||
it("rejects immediately on 403 without retrying", async () => {
|
||||
let callCount = 0;
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation(() => {
|
||||
callCount++;
|
||||
return Promise.resolve({
|
||||
ok: false,
|
||||
status: 403,
|
||||
text: () => Promise.resolve("Forbidden"),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "forbidden-no-retry" });
|
||||
const classification = makeClassification({ type: "text" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(callCount).toBe(1);
|
||||
expect(result.current.error).toContain("403");
|
||||
});
|
||||
|
||||
// ── Video skip-fetch ──────────────────────────────────────────────
|
||||
|
||||
it("skips fetch for video artifacts (like image)", async () => {
|
||||
const artifact = makeArtifact({
|
||||
id: "video-skip",
|
||||
mimeType: "video/mp4",
|
||||
});
|
||||
const classification = makeClassification({ type: "video" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
expect(result.current.isLoading).toBe(false);
|
||||
expect(result.current.content).toBeNull();
|
||||
expect(result.current.pdfUrl).toBeNull();
|
||||
expect(fetch).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
// ── PDF error paths ───────────────────────────────────────────────
|
||||
|
||||
it("sets error on PDF fetch failure (non-2xx)", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockResolvedValue({
|
||||
ok: false,
|
||||
status: 500,
|
||||
text: () => Promise.resolve("Server Error"),
|
||||
}),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "pdf-error" });
|
||||
const classification = makeClassification({ type: "pdf" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("500");
|
||||
expect(result.current.pdfUrl).toBeNull();
|
||||
});
|
||||
|
||||
it("sets error on PDF network failure", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockRejectedValue(new Error("PDF network failure")),
|
||||
);
|
||||
|
||||
const artifact = makeArtifact({ id: "pdf-network-error" });
|
||||
const classification = makeClassification({ type: "pdf" });
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
|
||||
await waitFor(
|
||||
() => {
|
||||
expect(result.current.error).toBeTruthy();
|
||||
},
|
||||
{ timeout: 2500 },
|
||||
);
|
||||
|
||||
expect(result.current.error).toContain("PDF network failure");
|
||||
expect(result.current.pdfUrl).toBeNull();
|
||||
});
|
||||
|
||||
// ── LRU cache eviction ────────────────────────────────────────────
|
||||
|
||||
it("evicts oldest entry when cache exceeds 12 items", async () => {
|
||||
vi.stubGlobal(
|
||||
"fetch",
|
||||
vi.fn().mockImplementation((url: string) => {
|
||||
const fileId = url.match(/files\/([^/]+)\/download/)?.[1] ?? "unknown";
|
||||
return Promise.resolve({
|
||||
ok: true,
|
||||
text: () => Promise.resolve(`content-${fileId}`),
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
const classification = makeClassification({ type: "text" });
|
||||
|
||||
// Fill the cache with 12 entries (cache max = 12)
|
||||
for (let i = 0; i < 12; i++) {
|
||||
const artifact = makeArtifact({
|
||||
id: `lru-${i}`,
|
||||
sourceUrl: `/api/proxy/api/workspace/files/lru-${i}/download`,
|
||||
});
|
||||
const { result } = renderHook(() =>
|
||||
useArtifactContent(artifact, classification),
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(result.current.isLoading).toBe(false);
|
||||
});
|
||||
}
|
||||
|
||||
// All 12 should be cached
|
||||
expect(getCachedArtifactContent("lru-0")).toBe("content-lru-0");
|
||||
expect(getCachedArtifactContent("lru-11")).toBe("content-lru-11");
|
||||
|
||||
// Adding a 13th should evict lru-0 (the oldest)
|
||||
const artifact13 = makeArtifact({
|
||||
id: "lru-12",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/lru-12/download",
|
||||
});
|
||||
const { result: result13 } = renderHook(() =>
|
||||
useArtifactContent(artifact13, classification),
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(result13.current.isLoading).toBe(false);
|
||||
});
|
||||
|
||||
expect(getCachedArtifactContent("lru-0")).toBeUndefined();
|
||||
expect(getCachedArtifactContent("lru-1")).toBe("content-lru-1");
|
||||
expect(getCachedArtifactContent("lru-12")).toBe("content-lru-12");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -85,4 +85,35 @@ describe("buildReactArtifactSrcDoc", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("box-sizing: border-box");
|
||||
});
|
||||
|
||||
it("supports a named previewProps export in the runtime", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("moduleExports.previewProps");
|
||||
expect(doc).toContain("React.createElement(Component, previewProps || {})");
|
||||
});
|
||||
|
||||
it("includes a helpful message for components that expect props", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("This component appears to expect props.");
|
||||
expect(doc).toContain("previewProps");
|
||||
});
|
||||
|
||||
it("checks componentExpectsProps on the raw component before wrapping", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("RawComponent.length > 0");
|
||||
expect(doc).toContain("wrapWithProviders(RawComponent");
|
||||
});
|
||||
|
||||
it("wrapWithProviders forwards props to the wrapped component", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain("function WrappedArtifactPreview(props)");
|
||||
expect(doc).toContain("React.createElement(Component, props)");
|
||||
});
|
||||
|
||||
it("supports named exported components and provider wrappers in the runtime", () => {
|
||||
const doc = buildReactArtifactSrcDoc("module.exports = {};", "A", STYLES);
|
||||
expect(doc).toContain('name.endsWith("Provider")');
|
||||
expect(doc).toContain("/^[A-Z]/.test(name)");
|
||||
expect(doc).toContain("wrapWithProviders");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -169,8 +169,8 @@ export function buildReactArtifactSrcDoc(
|
||||
return Component;
|
||||
}
|
||||
|
||||
return function WrappedArtifactPreview() {
|
||||
let tree = React.createElement(Component);
|
||||
return function WrappedArtifactPreview(props) {
|
||||
let tree = React.createElement(Component, props);
|
||||
|
||||
for (let i = providers.length - 1; i >= 0; i -= 1) {
|
||||
tree = React.createElement(providers[i], null, tree);
|
||||
@@ -180,6 +180,17 @@ export function buildReactArtifactSrcDoc(
|
||||
};
|
||||
}
|
||||
|
||||
function getPreviewProps(moduleExports) {
|
||||
if (
|
||||
moduleExports.previewProps &&
|
||||
typeof moduleExports.previewProps === "object"
|
||||
) {
|
||||
return moduleExports.previewProps;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
function require(name) {
|
||||
if (name === "react") {
|
||||
return React;
|
||||
@@ -235,6 +246,11 @@ export function buildReactArtifactSrcDoc(
|
||||
|
||||
render() {
|
||||
if (this.state.error) {
|
||||
const propsHelp =
|
||||
this.props.componentExpectsProps && !this.props.hasPreviewProps
|
||||
? "\\n\\nThis component appears to expect props. Export a named previewProps object with sample values to render it in artifact preview."
|
||||
: "";
|
||||
|
||||
return React.createElement(
|
||||
"div",
|
||||
{
|
||||
@@ -249,7 +265,9 @@ export function buildReactArtifactSrcDoc(
|
||||
whiteSpace: "pre-wrap",
|
||||
},
|
||||
},
|
||||
this.state.error.stack || this.state.error.message || String(this.state.error),
|
||||
(this.state.error.stack ||
|
||||
this.state.error.message ||
|
||||
String(this.state.error)) + propsHelp,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -296,16 +314,19 @@ export function buildReactArtifactSrcDoc(
|
||||
moduleExports.App = executionResult.app;
|
||||
}
|
||||
|
||||
const Component = wrapWithProviders(
|
||||
getRenderableCandidate(moduleExports),
|
||||
moduleExports,
|
||||
);
|
||||
const RawComponent = getRenderableCandidate(moduleExports);
|
||||
const componentExpectsProps = RawComponent.length > 0;
|
||||
const Component = wrapWithProviders(RawComponent, moduleExports);
|
||||
const previewProps = getPreviewProps(moduleExports);
|
||||
|
||||
ReactDOM.createRoot(rootElement).render(
|
||||
React.createElement(
|
||||
PreviewErrorBoundary,
|
||||
null,
|
||||
React.createElement(Component),
|
||||
{
|
||||
componentExpectsProps: componentExpectsProps,
|
||||
hasPreviewProps: previewProps != null,
|
||||
},
|
||||
React.createElement(Component, previewProps || {}),
|
||||
),
|
||||
);
|
||||
} catch (error) {
|
||||
|
||||
@@ -48,4 +48,104 @@ describe("transpileReactArtifactSource", () => {
|
||||
expect(out).not.toContain(": string");
|
||||
expect(out).toContain("function greet(name)");
|
||||
});
|
||||
|
||||
it("transpiles a concrete props-based artifact with previewProps", async () => {
|
||||
const src = `
|
||||
import React, { FC, useState } from "react";
|
||||
|
||||
interface ArtifactFile {
|
||||
id: string;
|
||||
name: string;
|
||||
mimeType: string;
|
||||
url: string;
|
||||
sizeBytes: number;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
files: ArtifactFile[];
|
||||
onSelect: (file: ArtifactFile) => void;
|
||||
}
|
||||
|
||||
export const previewProps: Props = {
|
||||
files: [
|
||||
{
|
||||
id: "1",
|
||||
name: "report.png",
|
||||
mimeType: "image/png",
|
||||
url: "/report.png",
|
||||
sizeBytes: 2048,
|
||||
},
|
||||
],
|
||||
onSelect: () => {},
|
||||
};
|
||||
|
||||
const ArtifactList: FC<Props> = ({ files, onSelect }) => {
|
||||
const [selected, setSelected] = useState<string | null>(null);
|
||||
|
||||
const handleClick = (file: ArtifactFile) => {
|
||||
setSelected(file.id);
|
||||
onSelect(file);
|
||||
};
|
||||
|
||||
return (
|
||||
<ul>
|
||||
{files.map((file) => (
|
||||
<li key={file.id} onClick={() => handleClick(file)}>
|
||||
<span>{selected === file.id ? "selected" : file.name}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
);
|
||||
};
|
||||
|
||||
export default ArtifactList;
|
||||
`;
|
||||
|
||||
const out = await transpileReactArtifactSource(src, "ArtifactList.tsx");
|
||||
|
||||
expect(out).toContain("exports.previewProps");
|
||||
expect(out).toContain("exports.default = ArtifactList");
|
||||
expect(out).toContain("useState");
|
||||
expect(out).not.toContain("interface Props");
|
||||
expect(out).not.toContain("interface ArtifactFile");
|
||||
});
|
||||
|
||||
it("transpiles a named export artifact without a default export", async () => {
|
||||
const src = `
|
||||
export function ResultsGrid() {
|
||||
return (
|
||||
<section>
|
||||
<h1>Results</h1>
|
||||
<p>Named export preview</p>
|
||||
</section>
|
||||
);
|
||||
}
|
||||
`;
|
||||
|
||||
const out = await transpileReactArtifactSource(src, "ResultsGrid.tsx");
|
||||
|
||||
expect(out).toContain("exports.ResultsGrid = ResultsGrid");
|
||||
expect(out).toMatch(/\.createElement\(/);
|
||||
expect(out).not.toContain("<section>");
|
||||
});
|
||||
|
||||
it("transpiles a provider-wrapped artifact with separate provider and component exports", async () => {
|
||||
const src = `
|
||||
import React from "react";
|
||||
|
||||
export function DemoProvider({ children }: { children: React.ReactNode }) {
|
||||
return <div data-theme="demo">{children}</div>;
|
||||
}
|
||||
|
||||
export function DashboardCard() {
|
||||
return <main>Provider-backed preview</main>;
|
||||
}
|
||||
`;
|
||||
|
||||
const out = await transpileReactArtifactSource(src, "DashboardCard.tsx");
|
||||
|
||||
expect(out).toContain("exports.DemoProvider = DemoProvider");
|
||||
expect(out).toContain("exports.DashboardCard = DashboardCard");
|
||||
expect(out).not.toContain("React.ReactNode");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,12 +7,116 @@ import type { ArtifactClassification } from "../helpers";
|
||||
// Cap on cached text artifacts. Long sessions with many large artifacts
|
||||
// would otherwise hold every opened one in memory.
|
||||
const CONTENT_CACHE_MAX = 12;
|
||||
const CONTENT_FETCH_MAX_RETRIES = 2;
|
||||
const CONTENT_FETCH_RETRY_DELAY_MS = 500;
|
||||
|
||||
// Module-level LRU keyed by artifact id so a sibling action (e.g. Copy
|
||||
// in ArtifactPanelHeader) can read what the panel already fetched without
|
||||
// re-hitting the network.
|
||||
const contentCache = new Map<string, string>();
|
||||
|
||||
class ArtifactFetchError extends Error {}
|
||||
|
||||
function isTransientArtifactFetchStatus(status: number): boolean {
|
||||
return status === 408 || status === 429 || status >= 500;
|
||||
}
|
||||
|
||||
function sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
function getArtifactErrorMessage(body: unknown): string | null {
|
||||
if (typeof body === "string") {
|
||||
const trimmed = body.replace(/\s+/g, " ").trim();
|
||||
return trimmed || null;
|
||||
}
|
||||
|
||||
if (!body || typeof body !== "object") return null;
|
||||
|
||||
if (
|
||||
"detail" in body &&
|
||||
typeof body.detail === "string" &&
|
||||
body.detail.trim().length > 0
|
||||
) {
|
||||
return body.detail.trim();
|
||||
}
|
||||
|
||||
if (
|
||||
"error" in body &&
|
||||
typeof body.error === "string" &&
|
||||
body.error.trim().length > 0
|
||||
) {
|
||||
return body.error.trim();
|
||||
}
|
||||
|
||||
if (
|
||||
"detail" in body &&
|
||||
body.detail &&
|
||||
typeof body.detail === "object" &&
|
||||
"message" in body.detail &&
|
||||
typeof body.detail.message === "string" &&
|
||||
body.detail.message.trim().length > 0
|
||||
) {
|
||||
return body.detail.message.trim();
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
async function parseArtifactFetchError(response: Response): Promise<string> {
|
||||
const prefix = `Failed to fetch: ${response.status}`;
|
||||
const contentType =
|
||||
response.headers?.get?.("content-type")?.toLowerCase() ?? "";
|
||||
|
||||
try {
|
||||
if (
|
||||
contentType.includes("application/json") &&
|
||||
typeof response.json === "function"
|
||||
) {
|
||||
const body = await response.json();
|
||||
const detail = getArtifactErrorMessage(body);
|
||||
return detail ? `${prefix} ${detail}` : prefix;
|
||||
}
|
||||
|
||||
if (typeof response.text === "function") {
|
||||
const text = await response.text();
|
||||
const detail = getArtifactErrorMessage(text);
|
||||
return detail ? `${prefix} ${detail}` : prefix;
|
||||
}
|
||||
} catch {
|
||||
return prefix;
|
||||
}
|
||||
|
||||
return prefix;
|
||||
}
|
||||
|
||||
async function fetchArtifactResponse(url: string): Promise<Response> {
|
||||
for (let attempt = 0; attempt <= CONTENT_FETCH_MAX_RETRIES; attempt++) {
|
||||
try {
|
||||
const response = await fetch(url);
|
||||
if (response.ok) return response;
|
||||
|
||||
if (
|
||||
!isTransientArtifactFetchStatus(response.status) ||
|
||||
attempt === CONTENT_FETCH_MAX_RETRIES
|
||||
) {
|
||||
throw new ArtifactFetchError(await parseArtifactFetchError(response));
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof ArtifactFetchError) throw error;
|
||||
if (attempt === CONTENT_FETCH_MAX_RETRIES) {
|
||||
throw error instanceof Error
|
||||
? error
|
||||
: new Error("Failed to fetch artifact");
|
||||
}
|
||||
}
|
||||
|
||||
await sleep(CONTENT_FETCH_RETRY_DELAY_MS);
|
||||
}
|
||||
|
||||
throw new Error("Failed to fetch artifact");
|
||||
}
|
||||
|
||||
export function getCachedArtifactContent(id: string): string | undefined {
|
||||
return contentCache.get(id);
|
||||
}
|
||||
@@ -64,7 +168,7 @@ export function useArtifactContent(
|
||||
}, [artifact.id, isLoading]);
|
||||
|
||||
useEffect(() => {
|
||||
if (classification.type === "image") {
|
||||
if (classification.type === "image" || classification.type === "video") {
|
||||
setContent(null);
|
||||
setPdfUrl(null);
|
||||
setError(null);
|
||||
@@ -80,11 +184,8 @@ export function useArtifactContent(
|
||||
let objectUrl: string | null = null;
|
||||
setContent(null);
|
||||
setPdfUrl(null);
|
||||
fetch(artifact.sourceUrl)
|
||||
.then((res) => {
|
||||
if (!res.ok) throw new Error(`Failed to fetch: ${res.status}`);
|
||||
return res.blob();
|
||||
})
|
||||
fetchArtifactResponse(artifact.sourceUrl)
|
||||
.then((res) => res.blob())
|
||||
.then((blob) => {
|
||||
objectUrl = URL.createObjectURL(blob);
|
||||
if (cancelled) {
|
||||
@@ -121,11 +222,8 @@ export function useArtifactContent(
|
||||
cancelled = true;
|
||||
};
|
||||
}
|
||||
fetch(artifact.sourceUrl)
|
||||
.then((res) => {
|
||||
if (!res.ok) throw new Error(`Failed to fetch: ${res.status}`);
|
||||
return res.text();
|
||||
})
|
||||
fetchArtifactResponse(artifact.sourceUrl)
|
||||
.then((res) => res.text())
|
||||
.then((text) => {
|
||||
if (!cancelled) {
|
||||
if (cache.size >= CONTENT_CACHE_MAX) {
|
||||
|
||||
@@ -1,5 +1,31 @@
|
||||
import type { ArtifactRef } from "../../store";
|
||||
|
||||
const MAX_RETRIES = 2;
|
||||
const RETRY_DELAY_MS = 500;
|
||||
|
||||
function isTransientError(status: number): boolean {
|
||||
return status >= 500 || status === 408 || status === 429;
|
||||
}
|
||||
|
||||
class DownloadError extends Error {}
|
||||
|
||||
async function fetchWithRetry(url: string, retries: number): Promise<Response> {
|
||||
for (let attempt = 0; attempt <= retries; attempt++) {
|
||||
try {
|
||||
const res = await fetch(url);
|
||||
if (res.ok) return res;
|
||||
if (!isTransientError(res.status) || attempt === retries) {
|
||||
throw new DownloadError(`Download failed: ${res.status}`);
|
||||
}
|
||||
} catch (error) {
|
||||
if (error instanceof DownloadError) throw error;
|
||||
if (attempt === retries) throw error;
|
||||
}
|
||||
await new Promise((r) => setTimeout(r, RETRY_DELAY_MS));
|
||||
}
|
||||
throw new Error("Unreachable");
|
||||
}
|
||||
|
||||
/**
|
||||
* Trigger a file download from an artifact URL.
|
||||
*
|
||||
@@ -7,26 +33,28 @@ import type { ArtifactRef } from "../../store";
|
||||
* ignores the `download` attribute on cross-origin responses (GCS signed
|
||||
* URLs), and some browsers require the anchor to be attached to the DOM
|
||||
* before `.click()` fires the download.
|
||||
*
|
||||
* Retries up to {@link MAX_RETRIES} times on transient server errors (5xx,
|
||||
* 408, 429) to handle intermittent proxy/GCS failures.
|
||||
*/
|
||||
export function downloadArtifact(artifact: ArtifactRef): Promise<void> {
|
||||
// Replace path separators, Windows-reserved chars, control chars, and
|
||||
// parent-dir sequences so the browser-assigned filename is safe to write
|
||||
// anywhere on the user's filesystem.
|
||||
const safeName =
|
||||
artifact.title
|
||||
.replace(/\.\./g, "_")
|
||||
.replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
|
||||
.replace(/^\.+/, "") || "download";
|
||||
return fetch(artifact.sourceUrl)
|
||||
.then((res) => {
|
||||
if (!res.ok) throw new Error(`Download failed: ${res.status}`);
|
||||
return res.blob();
|
||||
})
|
||||
const collapsedDots = artifact.title.replace(/\.\./g, "");
|
||||
const hasVisibleName = collapsedDots.replace(/^\.+/, "").length > 0;
|
||||
const safeName = artifact.title
|
||||
.replace(/\.\./g, "_")
|
||||
.replace(/[\\/:*?"<>|\x00-\x1f]/g, "_")
|
||||
.replace(/^\.+/, "");
|
||||
|
||||
return fetchWithRetry(artifact.sourceUrl, MAX_RETRIES)
|
||||
.then((res) => res.blob())
|
||||
.then((blob) => {
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement("a");
|
||||
a.href = url;
|
||||
a.download = safeName;
|
||||
a.download = safeName && hasVisibleName ? safeName : "download";
|
||||
document.body.appendChild(a);
|
||||
a.click();
|
||||
a.remove();
|
||||
|
||||
@@ -56,7 +56,7 @@ describe("classifyArtifact", () => {
|
||||
expect(classifyArtifact("application/octet-stream", "x").openable).toBe(
|
||||
false,
|
||||
);
|
||||
expect(classifyArtifact("video/mp4", "clip.mp4").openable).toBe(false);
|
||||
expect(classifyArtifact("audio/mpeg", "track.mp3").openable).toBe(false);
|
||||
});
|
||||
|
||||
it("defaults unknown extension+MIME to download-only (not text)", () => {
|
||||
@@ -76,4 +76,398 @@ describe("classifyArtifact", () => {
|
||||
const c = classifyArtifact("text/plain", "data.csv");
|
||||
expect(c.type).toBe("csv");
|
||||
});
|
||||
|
||||
it("classifies video/mp4 as video (previewable)", () => {
|
||||
const c = classifyArtifact("video/mp4", "clip.mp4");
|
||||
expect(c.type).toBe("video");
|
||||
expect(c.openable).toBe(true);
|
||||
});
|
||||
|
||||
it("classifies video/webm as video (previewable)", () => {
|
||||
const c = classifyArtifact("video/webm", "clip.webm");
|
||||
expect(c.type).toBe("video");
|
||||
expect(c.openable).toBe(true);
|
||||
});
|
||||
|
||||
// ── Extension coverage ────────────────────────────────────────────
|
||||
|
||||
it("routes .htm as html (not just .html)", () => {
|
||||
const c = classifyArtifact(null, "page.htm");
|
||||
expect(c.type).toBe("html");
|
||||
expect(c.hasSourceToggle).toBe(true);
|
||||
});
|
||||
|
||||
it("routes .json as json with source toggle", () => {
|
||||
const c = classifyArtifact(null, "config.json");
|
||||
expect(c.type).toBe("json");
|
||||
expect(c.hasSourceToggle).toBe(true);
|
||||
});
|
||||
|
||||
it("routes .txt as text", () => {
|
||||
expect(classifyArtifact(null, "notes.txt").type).toBe("text");
|
||||
});
|
||||
|
||||
it("routes .log as text", () => {
|
||||
expect(classifyArtifact(null, "server.log").type).toBe("text");
|
||||
});
|
||||
|
||||
it("routes .mdx as markdown", () => {
|
||||
expect(classifyArtifact(null, "docs.mdx").type).toBe("markdown");
|
||||
});
|
||||
|
||||
it("routes browser-safe video extensions to video", () => {
|
||||
for (const ext of [".mp4", ".webm", ".m4v"]) {
|
||||
const c = classifyArtifact(null, `clip${ext}`);
|
||||
expect(c.type).toBe("video");
|
||||
expect(c.openable).toBe(true);
|
||||
}
|
||||
});
|
||||
|
||||
it("keeps legacy or unsupported video extensions download-only", () => {
|
||||
for (const ext of [".ogg", ".mov", ".avi", ".mkv", ".flv", ".mpeg"]) {
|
||||
const c = classifyArtifact(null, `clip${ext}`);
|
||||
expect(c.type).toBe("download-only");
|
||||
expect(c.openable).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("routes all code extensions to code", () => {
|
||||
const codeExts = [
|
||||
"main.js",
|
||||
"app.ts",
|
||||
"theme.scss",
|
||||
"legacy.less",
|
||||
"schema.graphql",
|
||||
"query.gql",
|
||||
"api.proto",
|
||||
"main.dart",
|
||||
"lib.rb",
|
||||
"server.rs",
|
||||
"App.java",
|
||||
"main.c",
|
||||
"util.cpp",
|
||||
"header.h",
|
||||
"Program.cs",
|
||||
"index.php",
|
||||
"main.swift",
|
||||
"App.kt",
|
||||
"run.sh",
|
||||
"start.bash",
|
||||
"prompt.zsh",
|
||||
"config.toml",
|
||||
"settings.ini",
|
||||
"app.cfg",
|
||||
"query.sql",
|
||||
"analysis.r",
|
||||
"game.lua",
|
||||
"script.pl",
|
||||
"Calc.scala",
|
||||
];
|
||||
for (const file of codeExts) {
|
||||
expect(classifyArtifact(null, file).type).toBe("code");
|
||||
}
|
||||
});
|
||||
|
||||
it("routes config filenames and extensions to code", () => {
|
||||
const configFiles = [
|
||||
".env",
|
||||
".env.local",
|
||||
"app.properties",
|
||||
"service.conf",
|
||||
".gitignore",
|
||||
"Dockerfile",
|
||||
"Makefile",
|
||||
];
|
||||
|
||||
for (const file of configFiles) {
|
||||
expect(classifyArtifact(null, file).type).toBe("code");
|
||||
}
|
||||
});
|
||||
|
||||
it("routes .jsonl as code for now", () => {
|
||||
const c = classifyArtifact(null, "events.jsonl");
|
||||
expect(c.type).toBe("code");
|
||||
});
|
||||
|
||||
it("routes .tsv as csv/spreadsheet", () => {
|
||||
const c = classifyArtifact(null, "table.tsv");
|
||||
expect(c.type).toBe("csv");
|
||||
expect(c.hasSourceToggle).toBe(true);
|
||||
});
|
||||
|
||||
it("routes .ics and .vcf as text", () => {
|
||||
expect(classifyArtifact(null, "calendar.ics").type).toBe("text");
|
||||
expect(classifyArtifact(null, "contact.vcf").type).toBe("text");
|
||||
});
|
||||
|
||||
it("routes all image extensions to image", () => {
|
||||
for (const ext of [".jpg", ".jpeg", ".gif", ".webp", ".bmp", ".ico"]) {
|
||||
expect(classifyArtifact(null, `file${ext}`).type).toBe("image");
|
||||
}
|
||||
});
|
||||
|
||||
// ── MIME fallback coverage ────────────────────────────────────────
|
||||
|
||||
it("routes application/json MIME to json", () => {
|
||||
const c = classifyArtifact("application/json", "noext");
|
||||
expect(c.type).toBe("json");
|
||||
});
|
||||
|
||||
it("routes text/x-* MIME prefix to code", () => {
|
||||
expect(classifyArtifact("text/x-python", "noext").type).toBe("code");
|
||||
expect(classifyArtifact("text/x-c", "noext").type).toBe("code");
|
||||
expect(classifyArtifact("text/x-java-source", "noext").type).toBe("code");
|
||||
});
|
||||
|
||||
it("routes react MIME types to react", () => {
|
||||
expect(classifyArtifact("text/jsx", "noext").type).toBe("react");
|
||||
expect(classifyArtifact("text/tsx", "noext").type).toBe("react");
|
||||
expect(classifyArtifact("application/jsx", "noext").type).toBe("react");
|
||||
expect(classifyArtifact("application/x-typescript-jsx", "noext").type).toBe(
|
||||
"react",
|
||||
);
|
||||
});
|
||||
|
||||
it("routes JavaScript/TypeScript MIME to code", () => {
|
||||
expect(classifyArtifact("application/javascript", "noext").type).toBe(
|
||||
"code",
|
||||
);
|
||||
expect(classifyArtifact("text/javascript", "noext").type).toBe("code");
|
||||
expect(classifyArtifact("application/typescript", "noext").type).toBe(
|
||||
"code",
|
||||
);
|
||||
expect(classifyArtifact("text/typescript", "noext").type).toBe("code");
|
||||
});
|
||||
|
||||
it("routes XML MIME to code", () => {
|
||||
expect(classifyArtifact("application/xml", "noext").type).toBe("code");
|
||||
expect(classifyArtifact("text/xml", "noext").type).toBe("code");
|
||||
});
|
||||
|
||||
it("routes text/x-markdown MIME to markdown", () => {
|
||||
expect(classifyArtifact("text/x-markdown", "noext").type).toBe("markdown");
|
||||
});
|
||||
|
||||
it("routes text/csv MIME to csv", () => {
|
||||
expect(classifyArtifact("text/csv", "noext").type).toBe("csv");
|
||||
});
|
||||
|
||||
it("routes TSV MIME to csv", () => {
|
||||
expect(classifyArtifact("text/tab-separated-values", "noext").type).toBe(
|
||||
"csv",
|
||||
);
|
||||
});
|
||||
|
||||
it("routes unknown text/* MIME to text (not download-only)", () => {
|
||||
expect(classifyArtifact("text/rtf", "noext").type).toBe("text");
|
||||
});
|
||||
|
||||
it("routes browser-safe image MIME types to image", () => {
|
||||
expect(classifyArtifact("image/avif", "noext").type).toBe("image");
|
||||
});
|
||||
|
||||
it("keeps unsupported image MIME types download-only", () => {
|
||||
for (const mime of [
|
||||
"image/tiff",
|
||||
"image/x-portable-pixmap",
|
||||
"image/x-portable-graymap",
|
||||
]) {
|
||||
const c = classifyArtifact(mime, "noext");
|
||||
expect(c.type).toBe("download-only");
|
||||
expect(c.openable).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("routes browser-safe video MIME types to video", () => {
|
||||
expect(classifyArtifact("video/mp4", "noext").type).toBe("video");
|
||||
expect(classifyArtifact("video/webm", "noext").type).toBe("video");
|
||||
});
|
||||
|
||||
it("keeps legacy or unsupported video MIME types download-only", () => {
|
||||
for (const mime of [
|
||||
"video/x-msvideo",
|
||||
"video/x-flv",
|
||||
"video/mpeg",
|
||||
"video/quicktime",
|
||||
"video/x-matroska",
|
||||
"video/ogg",
|
||||
]) {
|
||||
const c = classifyArtifact(mime, "noext");
|
||||
expect(c.type).toBe("download-only");
|
||||
expect(c.openable).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
// ── BINARY_MIMES coverage ────────────────────────────────────────
|
||||
|
||||
it("treats all BINARY_MIMES entries as download-only", () => {
|
||||
const binaryMimes = [
|
||||
"application/zip",
|
||||
"application/x-zip-compressed",
|
||||
"application/gzip",
|
||||
"application/x-tar",
|
||||
"application/x-rar-compressed",
|
||||
"application/x-7z-compressed",
|
||||
"application/octet-stream",
|
||||
"application/x-executable",
|
||||
"application/x-msdos-program",
|
||||
"application/vnd.microsoft.portable-executable",
|
||||
];
|
||||
for (const mime of binaryMimes) {
|
||||
const c = classifyArtifact(mime, "noext");
|
||||
expect(c.openable).toBe(false);
|
||||
expect(c.type).toBe("download-only");
|
||||
}
|
||||
});
|
||||
|
||||
it("treats audio/* MIME as download-only", () => {
|
||||
expect(classifyArtifact("audio/mpeg", "noext").openable).toBe(false);
|
||||
expect(classifyArtifact("audio/wav", "noext").openable).toBe(false);
|
||||
expect(classifyArtifact("audio/ogg", "noext").openable).toBe(false);
|
||||
});
|
||||
|
||||
// ── Size gate edge cases ──────────────────────────────────────────
|
||||
|
||||
it("does NOT gate files at exactly 10MB (boundary is >10MB)", () => {
|
||||
const tenMB = 10 * 1024 * 1024;
|
||||
const c = classifyArtifact("text/plain", "exact.txt", tenMB);
|
||||
expect(c.type).toBe("text");
|
||||
expect(c.openable).toBe(true);
|
||||
});
|
||||
|
||||
it("gates files at 10MB + 1 byte", () => {
|
||||
const overTenMB = 10 * 1024 * 1024 + 1;
|
||||
const c = classifyArtifact("text/plain", "big.txt", overTenMB);
|
||||
expect(c.type).toBe("download-only");
|
||||
expect(c.openable).toBe(false);
|
||||
});
|
||||
|
||||
it("does not gate when sizeBytes is 0", () => {
|
||||
const c = classifyArtifact("text/plain", "empty.txt", 0);
|
||||
expect(c.type).toBe("text");
|
||||
expect(c.openable).toBe(true);
|
||||
});
|
||||
|
||||
it("does not gate when sizeBytes is undefined", () => {
|
||||
const c = classifyArtifact("text/plain", "file.txt", undefined);
|
||||
expect(c.type).toBe("text");
|
||||
expect(c.openable).toBe(true);
|
||||
});
|
||||
|
||||
// ── Extension over MIME priority ──────────────────────────────────
|
||||
|
||||
it("extension wins over MIME for JSON (MIME says text, ext says json)", () => {
|
||||
const c = classifyArtifact("text/plain", "data.json");
|
||||
expect(c.type).toBe("json");
|
||||
});
|
||||
|
||||
it("extension wins over MIME for markdown", () => {
|
||||
const c = classifyArtifact("text/plain", "README.md");
|
||||
expect(c.type).toBe("markdown");
|
||||
});
|
||||
|
||||
// ── Null/missing inputs ───────────────────────────────────────────
|
||||
|
||||
it("handles null MIME with no filename as download-only", () => {
|
||||
const c = classifyArtifact(null, undefined);
|
||||
expect(c.type).toBe("download-only");
|
||||
});
|
||||
|
||||
it("handles null MIME with empty filename as download-only", () => {
|
||||
const c = classifyArtifact(null, "");
|
||||
expect(c.type).toBe("download-only");
|
||||
});
|
||||
|
||||
it("handles known config files with no extension", () => {
|
||||
const c = classifyArtifact(null, "Makefile");
|
||||
expect(c.type).toBe("code");
|
||||
});
|
||||
|
||||
// ── Exotic/compound extensions must NOT open the side panel ───────
|
||||
// These are real file types agents might produce. Every single one
|
||||
// must be download-only so we never try to render binary garbage.
|
||||
|
||||
it("does not open .tar.gz (compound extension takes last segment)", () => {
|
||||
// getExtension("archive.tar.gz") → ".gz" which is not in EXT_KIND
|
||||
const c = classifyArtifact(null, "archive.tar.gz");
|
||||
expect(c.openable).toBe(false);
|
||||
expect(c.type).toBe("download-only");
|
||||
});
|
||||
|
||||
it("does not open .tar.bz2", () => {
|
||||
const c = classifyArtifact(null, "archive.tar.bz2");
|
||||
expect(c.openable).toBe(false);
|
||||
});
|
||||
|
||||
it("does not open .tar.xz", () => {
|
||||
const c = classifyArtifact(null, "archive.tar.xz");
|
||||
expect(c.openable).toBe(false);
|
||||
});
|
||||
|
||||
it("does not open common binary formats", () => {
|
||||
const binaries = [
|
||||
"setup.exe",
|
||||
"library.dll",
|
||||
"image.iso",
|
||||
"installer.dmg",
|
||||
"package.deb",
|
||||
"package.rpm",
|
||||
"module.wasm",
|
||||
"Main.class",
|
||||
"module.pyc",
|
||||
"app.apk",
|
||||
"game.pak",
|
||||
"model.onnx",
|
||||
"weights.pt",
|
||||
"data.parquet",
|
||||
"archive.rar",
|
||||
"archive.7z",
|
||||
"disk.vhd",
|
||||
"disk.vmdk",
|
||||
"firmware.bin",
|
||||
"core.dump",
|
||||
"database.sqlite",
|
||||
"database.db",
|
||||
"index.idx",
|
||||
];
|
||||
for (const file of binaries) {
|
||||
const c = classifyArtifact(null, file);
|
||||
expect(c.openable).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("does not open binary MIME types even with a misleading extension", () => {
|
||||
// Extension is unknown, MIME is binary
|
||||
const c = classifyArtifact("application/x-executable", "run.elf");
|
||||
expect(c.openable).toBe(false);
|
||||
});
|
||||
|
||||
it("does not open files with random/made-up extensions", () => {
|
||||
const weirdExts = [
|
||||
"output.xyz",
|
||||
"data.foo",
|
||||
"file.asdf",
|
||||
"thing.blargh",
|
||||
"result.out",
|
||||
"x.1234",
|
||||
];
|
||||
for (const file of weirdExts) {
|
||||
const c = classifyArtifact(null, file);
|
||||
expect(c.openable).toBe(false);
|
||||
expect(c.type).toBe("download-only");
|
||||
}
|
||||
});
|
||||
|
||||
it("does not open font files", () => {
|
||||
for (const file of ["sans.ttf", "serif.otf", "icon.woff", "icon.woff2"]) {
|
||||
expect(classifyArtifact(null, file).openable).toBe(false);
|
||||
}
|
||||
});
|
||||
|
||||
it("does not open certificate/key files", () => {
|
||||
// .pem and .key have no extension mapping and null MIME → download-only
|
||||
for (const file of ["cert.pem", "server.key", "ca.crt", "id.p12"]) {
|
||||
expect(classifyArtifact(null, file).openable).toBe(false);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
FileText,
|
||||
Image,
|
||||
Table,
|
||||
VideoCamera,
|
||||
} from "@phosphor-icons/react";
|
||||
import type { Icon } from "@phosphor-icons/react";
|
||||
|
||||
@@ -17,6 +18,7 @@ export interface ArtifactClassification {
|
||||
| "csv"
|
||||
| "json"
|
||||
| "image"
|
||||
| "video"
|
||||
| "pdf"
|
||||
| "text"
|
||||
| "download-only";
|
||||
@@ -38,6 +40,13 @@ const KIND: Record<string, ArtifactClassification> = {
|
||||
openable: true,
|
||||
hasSourceToggle: false,
|
||||
},
|
||||
video: {
|
||||
type: "video",
|
||||
icon: VideoCamera,
|
||||
label: "Video",
|
||||
openable: true,
|
||||
hasSourceToggle: false,
|
||||
},
|
||||
pdf: {
|
||||
type: "pdf",
|
||||
icon: FileText,
|
||||
@@ -113,8 +122,13 @@ const EXT_KIND: Record<string, string> = {
|
||||
".svg": "image",
|
||||
".bmp": "image",
|
||||
".ico": "image",
|
||||
".avif": "image",
|
||||
".mp4": "video",
|
||||
".webm": "video",
|
||||
".m4v": "video",
|
||||
".pdf": "pdf",
|
||||
".csv": "csv",
|
||||
".tsv": "csv",
|
||||
".html": "html",
|
||||
".htm": "html",
|
||||
".jsx": "react",
|
||||
@@ -122,11 +136,17 @@ const EXT_KIND: Record<string, string> = {
|
||||
".md": "markdown",
|
||||
".mdx": "markdown",
|
||||
".json": "json",
|
||||
".jsonl": "code",
|
||||
".txt": "text",
|
||||
".log": "text",
|
||||
".ics": "text",
|
||||
".vcf": "text",
|
||||
".env": "code",
|
||||
".gitignore": "code",
|
||||
// code extensions
|
||||
".js": "code",
|
||||
".ts": "code",
|
||||
".dart": "code",
|
||||
".py": "code",
|
||||
".rb": "code",
|
||||
".go": "code",
|
||||
@@ -142,11 +162,19 @@ const EXT_KIND: Record<string, string> = {
|
||||
".sh": "code",
|
||||
".bash": "code",
|
||||
".zsh": "code",
|
||||
".scss": "code",
|
||||
".sass": "code",
|
||||
".less": "code",
|
||||
".graphql": "code",
|
||||
".gql": "code",
|
||||
".proto": "code",
|
||||
".yml": "code",
|
||||
".yaml": "code",
|
||||
".toml": "code",
|
||||
".ini": "code",
|
||||
".cfg": "code",
|
||||
".conf": "code",
|
||||
".properties": "code",
|
||||
".sql": "code",
|
||||
".r": "code",
|
||||
".lua": "code",
|
||||
@@ -154,10 +182,16 @@ const EXT_KIND: Record<string, string> = {
|
||||
".scala": "code",
|
||||
};
|
||||
|
||||
const EXACT_FILENAME_KIND: Record<string, string> = {
|
||||
dockerfile: "code",
|
||||
makefile: "code",
|
||||
};
|
||||
|
||||
// Exact-match MIME → kind (fallback when extension doesn't match).
|
||||
const MIME_KIND: Record<string, string> = {
|
||||
"application/pdf": "pdf",
|
||||
"text/csv": "csv",
|
||||
"text/tab-separated-values": "csv",
|
||||
"text/html": "html",
|
||||
"text/jsx": "react",
|
||||
"text/tsx": "react",
|
||||
@@ -166,6 +200,9 @@ const MIME_KIND: Record<string, string> = {
|
||||
"text/markdown": "markdown",
|
||||
"text/x-markdown": "markdown",
|
||||
"application/json": "json",
|
||||
"application/x-ndjson": "code",
|
||||
"application/ndjson": "code",
|
||||
"application/jsonl": "code",
|
||||
"application/javascript": "code",
|
||||
"text/javascript": "code",
|
||||
"application/typescript": "code",
|
||||
@@ -182,11 +219,37 @@ const BINARY_MIMES = new Set([
|
||||
"application/x-rar-compressed",
|
||||
"application/x-7z-compressed",
|
||||
"application/octet-stream",
|
||||
"application/wasm",
|
||||
"application/x-executable",
|
||||
"application/x-msdos-program",
|
||||
"application/vnd.microsoft.portable-executable",
|
||||
]);
|
||||
|
||||
const PREVIEWABLE_IMAGE_MIMES = new Set([
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/webp",
|
||||
"image/svg+xml",
|
||||
"image/bmp",
|
||||
"image/x-icon",
|
||||
"image/vnd.microsoft.icon",
|
||||
"image/avif",
|
||||
]);
|
||||
|
||||
const PREVIEWABLE_VIDEO_MIMES = new Set([
|
||||
"video/mp4",
|
||||
"video/webm",
|
||||
"video/x-m4v",
|
||||
]);
|
||||
|
||||
function getBasename(filename?: string): string {
|
||||
if (!filename) return "";
|
||||
const normalized = filename.replace(/\\/g, "/");
|
||||
const parts = normalized.split("/");
|
||||
return parts[parts.length - 1]?.toLowerCase() ?? "";
|
||||
}
|
||||
|
||||
function getExtension(filename?: string): string {
|
||||
if (!filename) return "";
|
||||
const lastDot = filename.lastIndexOf(".");
|
||||
@@ -202,24 +265,36 @@ export function classifyArtifact(
|
||||
// Size gate: >10MB is download-only regardless of type.
|
||||
if (sizeBytes && sizeBytes > TEN_MB) return KIND["download-only"];
|
||||
|
||||
const basename = getBasename(filename);
|
||||
const exactKind = EXACT_FILENAME_KIND[basename];
|
||||
if (exactKind) return KIND[exactKind];
|
||||
|
||||
if (basename === ".env" || basename.startsWith(".env.")) {
|
||||
return KIND.code;
|
||||
}
|
||||
|
||||
// Extension first (more reliable than MIME for AI-generated files).
|
||||
const ext = getExtension(filename);
|
||||
const ext = getExtension(basename);
|
||||
const extKind = EXT_KIND[ext];
|
||||
if (extKind) return KIND[extKind];
|
||||
|
||||
// MIME fallbacks.
|
||||
const mime = (mimeType ?? "").toLowerCase();
|
||||
if (mime.startsWith("image/")) return KIND.image;
|
||||
if (PREVIEWABLE_IMAGE_MIMES.has(mime)) return KIND.image;
|
||||
if (PREVIEWABLE_VIDEO_MIMES.has(mime)) return KIND.video;
|
||||
const mimeKind = MIME_KIND[mime];
|
||||
if (mimeKind) return KIND[mimeKind];
|
||||
if (mime.startsWith("text/x-")) return KIND.code;
|
||||
if (
|
||||
BINARY_MIMES.has(mime) ||
|
||||
mime.startsWith("audio/") ||
|
||||
mime.startsWith("video/")
|
||||
mime.startsWith("image/") ||
|
||||
mime.startsWith("video/") ||
|
||||
mime.startsWith("font/")
|
||||
) {
|
||||
return KIND["download-only"];
|
||||
}
|
||||
if (BINARY_MIMES.has(mime) || mime.startsWith("audio/")) {
|
||||
return KIND["download-only"];
|
||||
}
|
||||
if (mime.startsWith("text/")) return KIND.text;
|
||||
|
||||
// Unknown extension + unknown MIME: don't open — we can't safely assume
|
||||
|
||||
@@ -83,6 +83,7 @@ export function useArtifactPanel() {
|
||||
const canCopy =
|
||||
classification != null &&
|
||||
classification.type !== "image" &&
|
||||
classification.type !== "video" &&
|
||||
classification.type !== "download-only" &&
|
||||
classification.type !== "pdf";
|
||||
|
||||
|
||||
@@ -64,10 +64,7 @@ export const ChatContainer = ({
|
||||
// open state drive layout width; an artifact generated in a stale session
|
||||
// state would otherwise shrink the chat column with no panel rendered.
|
||||
const isArtifactOpen = isArtifactsEnabled && isArtifactPanelOpen;
|
||||
useAutoOpenArtifacts({
|
||||
messages: isArtifactsEnabled ? messages : [],
|
||||
sessionId,
|
||||
});
|
||||
useAutoOpenArtifacts({ sessionId });
|
||||
const isBusy =
|
||||
status === "streaming" ||
|
||||
status === "submitted" ||
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import { describe, expect, it, beforeEach, afterEach } from "vitest";
|
||||
import { renderHook } from "@testing-library/react";
|
||||
import { useAutoOpenArtifacts } from "../useAutoOpenArtifacts";
|
||||
import { useCopilotUIStore } from "../../../store";
|
||||
|
||||
// Capture the real store actions before any test can replace them.
|
||||
const realOpenArtifact = useCopilotUIStore.getState().openArtifact;
|
||||
const realResetArtifactPanel = useCopilotUIStore.getState().resetArtifactPanel;
|
||||
|
||||
function resetStore() {
|
||||
useCopilotUIStore.setState({
|
||||
openArtifact: realOpenArtifact,
|
||||
resetArtifactPanel: realResetArtifactPanel,
|
||||
artifactPanel: {
|
||||
isOpen: false,
|
||||
isMinimized: false,
|
||||
isMaximized: false,
|
||||
width: 600,
|
||||
activeArtifact: null,
|
||||
history: [],
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
describe("useAutoOpenArtifacts", () => {
|
||||
beforeEach(resetStore);
|
||||
afterEach(resetStore);
|
||||
|
||||
it("does not auto-open artifacts on initial message load", () => {
|
||||
renderHook(() => useAutoOpenArtifacts({ sessionId: "session-1" }));
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
});
|
||||
|
||||
it("does not auto-open when rerendering within the same session", () => {
|
||||
const { rerender } = renderHook(
|
||||
({ sessionId }: { sessionId: string }) =>
|
||||
useAutoOpenArtifacts({ sessionId }),
|
||||
{ initialProps: { sessionId: "session-1" } },
|
||||
);
|
||||
|
||||
rerender({ sessionId: "session-1" });
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
});
|
||||
|
||||
it("panel should fully reset when session changes", () => {
|
||||
const artifact = {
|
||||
id: "file1",
|
||||
title: "image.png",
|
||||
mimeType: "image/png",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/file1/download",
|
||||
origin: "agent" as const,
|
||||
};
|
||||
useCopilotUIStore.getState().openArtifact(artifact);
|
||||
useCopilotUIStore.getState().openArtifact({
|
||||
...artifact,
|
||||
id: "file2",
|
||||
title: "second.png",
|
||||
sourceUrl: "/api/proxy/api/workspace/files/file2/download",
|
||||
});
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);
|
||||
|
||||
const { rerender } = renderHook(
|
||||
({ sessionId }: { sessionId: string }) =>
|
||||
useAutoOpenArtifacts({ sessionId }),
|
||||
{ initialProps: { sessionId: "session-1" } },
|
||||
);
|
||||
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(true);
|
||||
|
||||
rerender({ sessionId: "session-2" });
|
||||
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.isOpen).toBe(false);
|
||||
expect(s.activeArtifact).toBeNull();
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
});
|
||||
@@ -3,17 +3,19 @@ import { beforeEach, describe, expect, it } from "vitest";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { useAutoOpenArtifacts } from "./useAutoOpenArtifacts";
|
||||
|
||||
function assistantMessageWithText(id: string, text: string) {
|
||||
return {
|
||||
id,
|
||||
role: "assistant" as const,
|
||||
parts: [{ type: "text" as const, text }],
|
||||
};
|
||||
}
|
||||
|
||||
const A_ID = "11111111-0000-0000-0000-000000000000";
|
||||
const B_ID = "22222222-0000-0000-0000-000000000000";
|
||||
|
||||
function makeArtifact(id: string, title = `${id}.txt`) {
|
||||
return {
|
||||
id,
|
||||
title,
|
||||
mimeType: "text/plain",
|
||||
sourceUrl: `/api/proxy/api/workspace/files/${id}/download`,
|
||||
origin: "agent" as const,
|
||||
};
|
||||
}
|
||||
|
||||
function resetStore() {
|
||||
useCopilotUIStore.setState({
|
||||
artifactPanel: {
|
||||
@@ -30,111 +32,60 @@ function resetStore() {
|
||||
describe("useAutoOpenArtifacts", () => {
|
||||
beforeEach(resetStore);
|
||||
|
||||
it("does NOT auto-open on the initial hydration of message list (baseline pass)", () => {
|
||||
const messages = [
|
||||
assistantMessageWithText("m1", `[a](workspace://${A_ID})`),
|
||||
];
|
||||
renderHook(() =>
|
||||
useAutoOpenArtifacts({ messages: messages as any, sessionId: "s1" }),
|
||||
);
|
||||
// Initial run just records the baseline fingerprint; nothing opens.
|
||||
it("does not auto-open on initial render", () => {
|
||||
renderHook(() => useAutoOpenArtifacts({ sessionId: "s1" }));
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
});
|
||||
|
||||
it("auto-opens when an existing assistant message adds a new artifact", () => {
|
||||
// 1st render: baseline with no artifact.
|
||||
const initial = [assistantMessageWithText("m1", "thinking...")];
|
||||
it("does not auto-open when rerendering within the same session", () => {
|
||||
const { rerender } = renderHook(
|
||||
({ messages, sessionId }) =>
|
||||
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
|
||||
{ initialProps: { messages: initial, sessionId: "s1" } },
|
||||
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
|
||||
{ initialProps: { sessionId: "s1" } },
|
||||
);
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
|
||||
// 2nd render: same message id now contains an artifact link.
|
||||
act(() => {
|
||||
rerender({
|
||||
messages: [
|
||||
assistantMessageWithText("m1", `here: [A](workspace://${A_ID})`),
|
||||
],
|
||||
sessionId: "s1",
|
||||
});
|
||||
rerender({ sessionId: "s1" });
|
||||
});
|
||||
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
});
|
||||
|
||||
it("resets the panel state when sessionId changes", () => {
|
||||
useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
|
||||
useCopilotUIStore.getState().openArtifact(makeArtifact(B_ID, "b.txt"));
|
||||
|
||||
const { rerender } = renderHook(
|
||||
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
|
||||
{ initialProps: { sessionId: "s1" } },
|
||||
);
|
||||
|
||||
act(() => {
|
||||
rerender({ sessionId: "s2" });
|
||||
});
|
||||
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.isOpen).toBe(true);
|
||||
expect(s.activeArtifact?.id).toBe(A_ID);
|
||||
expect(s.isOpen).toBe(false);
|
||||
expect(s.activeArtifact).toBeNull();
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
|
||||
it("does not re-open when the fingerprint hasn't changed", () => {
|
||||
const msg = assistantMessageWithText("m1", `[A](workspace://${A_ID})`);
|
||||
it("does not carry a stale back stack into the next session", () => {
|
||||
useCopilotUIStore.getState().openArtifact(makeArtifact(A_ID, "a.txt"));
|
||||
useCopilotUIStore.getState().openArtifact(makeArtifact(B_ID, "b.txt"));
|
||||
|
||||
const { rerender } = renderHook(
|
||||
({ messages, sessionId }) =>
|
||||
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
|
||||
{ initialProps: { messages: [msg], sessionId: "s1" } },
|
||||
({ sessionId }) => useAutoOpenArtifacts({ sessionId }),
|
||||
{ initialProps: { sessionId: "s1" } },
|
||||
);
|
||||
// Baseline captured; no open.
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
|
||||
// Rerender identical content: no change in fingerprint → no open.
|
||||
act(() => {
|
||||
rerender({ messages: [msg], sessionId: "s1" });
|
||||
rerender({ sessionId: "s2" });
|
||||
});
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
});
|
||||
|
||||
it("auto-opens when a brand-new assistant message arrives after the baseline is established", () => {
|
||||
// First render: one message without artifacts → establishes baseline.
|
||||
const { rerender } = renderHook(
|
||||
({ messages, sessionId }) =>
|
||||
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
|
||||
{
|
||||
initialProps: {
|
||||
messages: [assistantMessageWithText("m1", "plain")] as any,
|
||||
sessionId: "s1",
|
||||
},
|
||||
},
|
||||
);
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
useCopilotUIStore.getState().openArtifact(makeArtifact("c", "c.txt"));
|
||||
|
||||
// Second render: a *new* assistant message with an artifact. Baseline
|
||||
// is already set, so this should auto-open.
|
||||
act(() => {
|
||||
rerender({
|
||||
messages: [
|
||||
assistantMessageWithText("m1", "plain"),
|
||||
assistantMessageWithText("m2", `[B](workspace://${B_ID})`),
|
||||
] as any,
|
||||
sessionId: "s1",
|
||||
});
|
||||
});
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.isOpen).toBe(true);
|
||||
expect(s.activeArtifact?.id).toBe(B_ID);
|
||||
});
|
||||
|
||||
it("resets hydration baseline when sessionId changes", () => {
|
||||
const { rerender } = renderHook(
|
||||
({ messages, sessionId }) =>
|
||||
useAutoOpenArtifacts({ messages: messages as any, sessionId }),
|
||||
{
|
||||
initialProps: {
|
||||
messages: [
|
||||
assistantMessageWithText("m1", `[A](workspace://${A_ID})`),
|
||||
] as any,
|
||||
sessionId: "s1",
|
||||
},
|
||||
},
|
||||
);
|
||||
// Switch to a new session — the first pass on the new session should
|
||||
// NOT auto-open (it's a fresh hydration).
|
||||
act(() => {
|
||||
rerender({
|
||||
messages: [
|
||||
assistantMessageWithText("m2", `[B](workspace://${B_ID})`),
|
||||
] as any,
|
||||
sessionId: "s2",
|
||||
});
|
||||
});
|
||||
expect(useCopilotUIStore.getState().artifactPanel.isOpen).toBe(false);
|
||||
expect(s.activeArtifact?.id).toBe("c");
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,91 +1,29 @@
|
||||
"use client";
|
||||
|
||||
import { UIDataTypes, UIMessage, UITools } from "ai";
|
||||
import { useEffect, useRef } from "react";
|
||||
import type { ArtifactRef } from "../../store";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
import { getMessageArtifacts } from "../ChatMessagesContainer/helpers";
|
||||
|
||||
function fingerprintArtifacts(artifacts: ArtifactRef[]): string {
|
||||
return artifacts
|
||||
.map((a) => `${a.id}:${a.title}:${a.mimeType ?? ""}:${a.sourceUrl}`)
|
||||
.join("|");
|
||||
}
|
||||
|
||||
interface UseAutoOpenArtifactsOptions {
|
||||
messages: UIMessage<unknown, UIDataTypes, UITools>[];
|
||||
sessionId: string | null;
|
||||
}
|
||||
|
||||
export function useAutoOpenArtifacts({
|
||||
messages,
|
||||
sessionId,
|
||||
}: UseAutoOpenArtifactsOptions) {
|
||||
const openArtifact = useCopilotUIStore((state) => state.openArtifact);
|
||||
const messageFingerprintsRef = useRef<Map<string, string>>(new Map());
|
||||
const hasInitializedRef = useRef(false);
|
||||
const resetArtifactPanel = useCopilotUIStore(
|
||||
(state) => state.resetArtifactPanel,
|
||||
);
|
||||
const prevSessionIdRef = useRef(sessionId);
|
||||
|
||||
useEffect(() => {
|
||||
messageFingerprintsRef.current = new Map();
|
||||
hasInitializedRef.current = false;
|
||||
}, [sessionId]);
|
||||
const isSessionChange = prevSessionIdRef.current !== sessionId;
|
||||
prevSessionIdRef.current = sessionId;
|
||||
|
||||
useEffect(() => {
|
||||
if (messages.length === 0) {
|
||||
messageFingerprintsRef.current = new Map();
|
||||
return;
|
||||
// Artifact previews should open only from an explicit user click.
|
||||
// When the session changes, fully clear the panel state so stale
|
||||
// active artifacts and back-stack entries never bleed into the next chat.
|
||||
if (isSessionChange) {
|
||||
resetArtifactPanel();
|
||||
}
|
||||
|
||||
// Only scan messages whose fingerprint might have changed since the
|
||||
// last pass: that's the last assistant message (currently streaming)
|
||||
// plus any assistant message whose id isn't in the baseline yet.
|
||||
// This keeps the cost O(new+tail), not O(all messages), on every chunk.
|
||||
const previous = messageFingerprintsRef.current;
|
||||
const nextFingerprints = new Map<string, string>(previous);
|
||||
let nextArtifact: ArtifactRef | null = null;
|
||||
const lastAssistantIdx = (() => {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
if (messages[i].role === "assistant") return i;
|
||||
}
|
||||
return -1;
|
||||
})();
|
||||
|
||||
for (let i = 0; i < messages.length; i++) {
|
||||
const message = messages[i];
|
||||
if (message.role !== "assistant") continue;
|
||||
const isTailAssistant = i === lastAssistantIdx;
|
||||
const isNewMessage = !previous.has(message.id);
|
||||
if (!isTailAssistant && !isNewMessage) continue;
|
||||
|
||||
const artifacts = getMessageArtifacts(message);
|
||||
const fingerprint = fingerprintArtifacts(artifacts);
|
||||
nextFingerprints.set(message.id, fingerprint);
|
||||
|
||||
if (!hasInitializedRef.current || fingerprint.length === 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const previousFingerprint = previous.get(message.id) ?? "";
|
||||
if (previousFingerprint === fingerprint) continue;
|
||||
|
||||
nextArtifact = artifacts[artifacts.length - 1] ?? nextArtifact;
|
||||
}
|
||||
|
||||
// Drop entries for messages that no longer exist (e.g. history truncated).
|
||||
const liveIds = new Set(messages.map((m) => m.id));
|
||||
for (const id of nextFingerprints.keys()) {
|
||||
if (!liveIds.has(id)) nextFingerprints.delete(id);
|
||||
}
|
||||
|
||||
messageFingerprintsRef.current = nextFingerprints;
|
||||
|
||||
if (!hasInitializedRef.current) {
|
||||
hasInitializedRef.current = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (nextArtifact) {
|
||||
openArtifact(nextArtifact);
|
||||
}
|
||||
}, [messages, openArtifact]);
|
||||
}, [sessionId, resetArtifactPanel]);
|
||||
}
|
||||
|
||||
@@ -19,8 +19,16 @@ describe("formatResetTime", () => {
|
||||
});
|
||||
|
||||
it("returns formatted date when over 24 hours away", () => {
|
||||
const result = formatResetTime("2025-06-17T00:00:00Z", now);
|
||||
expect(result).toMatch(/Tue/);
|
||||
const resetsAt = "2025-06-17T00:00:00Z";
|
||||
const result = formatResetTime(resetsAt, now);
|
||||
const expected = new Date(resetsAt).toLocaleString(undefined, {
|
||||
weekday: "short",
|
||||
hour: "numeric",
|
||||
minute: "2-digit",
|
||||
timeZoneName: "short",
|
||||
});
|
||||
|
||||
expect(result).toBe(expected);
|
||||
});
|
||||
|
||||
it("accepts a Date object for resetsAt", () => {
|
||||
|
||||
@@ -99,6 +99,50 @@ describe("artifactPanel store actions", () => {
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
|
||||
it("openArtifact does not resurrect a previously closed artifact into history", () => {
|
||||
const a = makeArtifact("a");
|
||||
const b = makeArtifact("b");
|
||||
useCopilotUIStore.getState().openArtifact(a);
|
||||
useCopilotUIStore.getState().closeArtifactPanel();
|
||||
useCopilotUIStore.getState().openArtifact(b);
|
||||
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.isOpen).toBe(true);
|
||||
expect(s.activeArtifact?.id).toBe("b");
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
|
||||
it("openArtifact ignores non-previewable artifacts", () => {
|
||||
const binary = {
|
||||
...makeArtifact("bin", "artifact.bin"),
|
||||
mimeType: "application/octet-stream",
|
||||
};
|
||||
|
||||
useCopilotUIStore.getState().openArtifact(binary);
|
||||
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.isOpen).toBe(false);
|
||||
expect(s.activeArtifact).toBeNull();
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
|
||||
it("resetArtifactPanel clears active artifact and history", () => {
|
||||
const a = makeArtifact("a");
|
||||
const b = makeArtifact("b");
|
||||
useCopilotUIStore.getState().openArtifact(a);
|
||||
useCopilotUIStore.getState().openArtifact(b);
|
||||
useCopilotUIStore.getState().maximizeArtifactPanel();
|
||||
|
||||
useCopilotUIStore.getState().resetArtifactPanel();
|
||||
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.isOpen).toBe(false);
|
||||
expect(s.isMinimized).toBe(false);
|
||||
expect(s.isMaximized).toBe(false);
|
||||
expect(s.activeArtifact).toBeNull();
|
||||
expect(s.history).toEqual([]);
|
||||
});
|
||||
|
||||
it("minimize/restore toggles isMinimized without touching activeArtifact", () => {
|
||||
const a = makeArtifact("a");
|
||||
useCopilotUIStore.getState().openArtifact(a);
|
||||
@@ -138,4 +182,35 @@ describe("artifactPanel store actions", () => {
|
||||
expect(s.width).toBe(720);
|
||||
expect(s.isMaximized).toBe(false);
|
||||
});
|
||||
|
||||
it("history is capped at 25 entries (MAX_HISTORY)", () => {
|
||||
// Open 27 artifacts sequentially (A0..A26). History should never exceed 25.
|
||||
for (let i = 0; i < 27; i++) {
|
||||
useCopilotUIStore.getState().openArtifact(makeArtifact(`a${i}`));
|
||||
}
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.activeArtifact?.id).toBe("a26");
|
||||
expect(s.history.length).toBe(25);
|
||||
// The oldest entry (a0) should have been dropped; a1 is the earliest surviving.
|
||||
expect(s.history[0].id).toBe("a1");
|
||||
expect(s.history[24].id).toBe("a25");
|
||||
});
|
||||
|
||||
it("clearCopilotLocalData resets artifact panel to default", () => {
|
||||
const a = makeArtifact("a");
|
||||
const b = makeArtifact("b");
|
||||
useCopilotUIStore.getState().openArtifact(a);
|
||||
useCopilotUIStore.getState().openArtifact(b);
|
||||
useCopilotUIStore.getState().maximizeArtifactPanel();
|
||||
|
||||
useCopilotUIStore.getState().clearCopilotLocalData();
|
||||
|
||||
const s = useCopilotUIStore.getState().artifactPanel;
|
||||
expect(s.isOpen).toBe(false);
|
||||
expect(s.isMinimized).toBe(false);
|
||||
expect(s.isMaximized).toBe(false);
|
||||
expect(s.activeArtifact).toBeNull();
|
||||
expect(s.history).toEqual([]);
|
||||
expect(s.width).toBe(600); // DEFAULT_PANEL_WIDTH
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { Key, storage } from "@/services/storage/local-storage";
|
||||
import { create } from "zustand";
|
||||
import { clearContentCache } from "./components/ArtifactPanel/components/useArtifactContent";
|
||||
import { classifyArtifact } from "./components/ArtifactPanel/helpers";
|
||||
import { ORIGINAL_TITLE, parseSessionIDs } from "./helpers";
|
||||
|
||||
export interface DeleteTarget {
|
||||
@@ -92,6 +93,10 @@ function persistCompletedSessions(ids: Set<string>) {
|
||||
}
|
||||
}
|
||||
|
||||
function isPreviewableArtifact(ref: ArtifactRef): boolean {
|
||||
return classifyArtifact(ref.mimeType, ref.title, ref.sizeBytes).openable;
|
||||
}
|
||||
|
||||
interface CopilotUIState {
|
||||
/** Prompt extracted from URL hash (e.g. /copilot#prompt=...) for input prefill. */
|
||||
initialPrompt: string | null;
|
||||
@@ -121,6 +126,7 @@ interface CopilotUIState {
|
||||
artifactPanel: ArtifactPanelState;
|
||||
openArtifact: (ref: ArtifactRef) => void;
|
||||
closeArtifactPanel: () => void;
|
||||
resetArtifactPanel: () => void;
|
||||
minimizeArtifactPanel: () => void;
|
||||
maximizeArtifactPanel: () => void;
|
||||
restoreArtifactPanel: () => void;
|
||||
@@ -203,14 +209,20 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
},
|
||||
openArtifact: (ref) =>
|
||||
set((state) => {
|
||||
if (!isPreviewableArtifact(ref)) return state;
|
||||
|
||||
const { activeArtifact, history: prevHistory } = state.artifactPanel;
|
||||
const topOfHistory = prevHistory[prevHistory.length - 1];
|
||||
const isReturningToTop = topOfHistory?.id === ref.id;
|
||||
const shouldPushHistory =
|
||||
state.artifactPanel.isOpen &&
|
||||
activeArtifact != null &&
|
||||
activeArtifact.id !== ref.id;
|
||||
const MAX_HISTORY = 25;
|
||||
const history = isReturningToTop
|
||||
? prevHistory.slice(0, -1)
|
||||
: activeArtifact && activeArtifact.id !== ref.id
|
||||
? [...prevHistory, activeArtifact].slice(-MAX_HISTORY)
|
||||
: shouldPushHistory
|
||||
? [...prevHistory, activeArtifact!].slice(-MAX_HISTORY)
|
||||
: prevHistory;
|
||||
return {
|
||||
artifactPanel: {
|
||||
@@ -231,6 +243,17 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
|
||||
history: [],
|
||||
},
|
||||
})),
|
||||
resetArtifactPanel: () =>
|
||||
set((state) => ({
|
||||
artifactPanel: {
|
||||
...state.artifactPanel,
|
||||
isOpen: false,
|
||||
isMinimized: false,
|
||||
isMaximized: false,
|
||||
activeArtifact: null,
|
||||
history: [],
|
||||
},
|
||||
})),
|
||||
minimizeArtifactPanel: () =>
|
||||
set((state) => ({
|
||||
artifactPanel: { ...state.artifactPanel, isMinimized: true },
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState } from "react";
|
||||
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
|
||||
import { Button } from "@/components/atoms/Button/Button";
|
||||
import type { BlockOutputResponse } from "@/app/api/__generated__/models/blockOutputResponse";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
} from "@/components/contextual/OutputRenderers";
|
||||
import type { OutputMetadata } from "@/components/contextual/OutputRenderers";
|
||||
import { isWorkspaceURI, parseWorkspaceURI } from "@/lib/workspace-uri";
|
||||
import { resolveForRenderer } from "@/app/(platform)/copilot/tools/ViewAgentOutput/ViewAgentOutput";
|
||||
import {
|
||||
ContentBadge,
|
||||
ContentCard,
|
||||
@@ -24,28 +22,6 @@ interface Props {
|
||||
|
||||
const COLLAPSED_LIMIT = 3;
|
||||
|
||||
function resolveForRenderer(value: unknown): {
|
||||
value: unknown;
|
||||
metadata?: OutputMetadata;
|
||||
} {
|
||||
if (!isWorkspaceURI(value)) return { value };
|
||||
|
||||
const parsed = parseWorkspaceURI(value);
|
||||
if (!parsed) return { value };
|
||||
|
||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
|
||||
const url = `/api/proxy${apiPath}`;
|
||||
|
||||
const metadata: OutputMetadata = {};
|
||||
if (parsed.mimeType) {
|
||||
metadata.mimeType = parsed.mimeType;
|
||||
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
|
||||
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
|
||||
}
|
||||
|
||||
return { value: url, metadata };
|
||||
}
|
||||
|
||||
function RenderOutputValue({ value }: { value: unknown }) {
|
||||
const resolved = resolveForRenderer(value);
|
||||
const renderer = globalRegistry.getRenderer(
|
||||
@@ -63,16 +39,6 @@ function RenderOutputValue({ value }: { value: unknown }) {
|
||||
);
|
||||
}
|
||||
|
||||
// Fallback for audio workspace refs
|
||||
if (
|
||||
isWorkspaceURI(value) &&
|
||||
resolved.metadata?.mimeType?.startsWith("audio/")
|
||||
) {
|
||||
return (
|
||||
<audio controls src={String(resolved.value)} className="mt-2 w-full" />
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import type { ToolUIPart } from "ai";
|
||||
import React from "react";
|
||||
import { getGetWorkspaceDownloadFileByIdUrl } from "@/app/api/__generated__/endpoints/workspace/workspace";
|
||||
import {
|
||||
globalRegistry,
|
||||
OutputItem,
|
||||
@@ -47,7 +46,7 @@ interface Props {
|
||||
part: ViewAgentOutputToolPart;
|
||||
}
|
||||
|
||||
function resolveForRenderer(value: unknown): {
|
||||
export function resolveForRenderer(value: unknown): {
|
||||
value: unknown;
|
||||
metadata?: OutputMetadata;
|
||||
} {
|
||||
@@ -56,17 +55,17 @@ function resolveForRenderer(value: unknown): {
|
||||
const parsed = parseWorkspaceURI(value);
|
||||
if (!parsed) return { value };
|
||||
|
||||
const apiPath = getGetWorkspaceDownloadFileByIdUrl(parsed.fileID);
|
||||
const url = `/api/proxy${apiPath}`;
|
||||
|
||||
// Pass workspace URIs through to the registry unchanged.
|
||||
// WorkspaceFileRenderer (priority 50) matches workspace:// URIs and
|
||||
// handles URL building, loading skeletons, and error states internally.
|
||||
// Previously this converted to a proxy URL which bypassed
|
||||
// WorkspaceFileRenderer, causing ImageRenderer (bare <img>) to match.
|
||||
const metadata: OutputMetadata = {};
|
||||
if (parsed.mimeType) {
|
||||
metadata.mimeType = parsed.mimeType;
|
||||
if (parsed.mimeType.startsWith("image/")) metadata.type = "image";
|
||||
else if (parsed.mimeType.startsWith("video/")) metadata.type = "video";
|
||||
}
|
||||
|
||||
return { value: url, metadata };
|
||||
return { value, metadata };
|
||||
}
|
||||
|
||||
function RenderOutputValue({ value }: { value: unknown }) {
|
||||
@@ -86,16 +85,6 @@ function RenderOutputValue({ value }: { value: unknown }) {
|
||||
);
|
||||
}
|
||||
|
||||
// Fallback for audio workspace refs
|
||||
if (
|
||||
isWorkspaceURI(value) &&
|
||||
resolved.metadata?.mimeType?.startsWith("audio/")
|
||||
) {
|
||||
return (
|
||||
<audio controls src={String(resolved.value)} className="mt-2 w-full" />
|
||||
);
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { resolveForRenderer } from "../ViewAgentOutput";
|
||||
import { globalRegistry } from "@/components/contextual/OutputRenderers";
|
||||
|
||||
describe("resolveForRenderer", () => {
|
||||
it("preserves workspace image URI for the registry to handle", () => {
|
||||
const result = resolveForRenderer("workspace://abc123#image/png");
|
||||
expect(String(result.value)).toMatch(/^workspace:\/\//);
|
||||
expect(result.metadata?.mimeType).toBe("image/png");
|
||||
});
|
||||
|
||||
it("preserves workspace video URI for the registry to handle", () => {
|
||||
const result = resolveForRenderer("workspace://vid456#video/mp4");
|
||||
expect(String(result.value)).toMatch(/^workspace:\/\//);
|
||||
expect(result.metadata?.mimeType).toBe("video/mp4");
|
||||
});
|
||||
|
||||
it("passes non-workspace values through unchanged", () => {
|
||||
const result = resolveForRenderer("just a string");
|
||||
expect(result.value).toBe("just a string");
|
||||
expect(result.metadata).toBeUndefined();
|
||||
});
|
||||
|
||||
it("passes non-string values through unchanged", () => {
|
||||
const obj = { foo: "bar" };
|
||||
const result = resolveForRenderer(obj);
|
||||
expect(result.value).toBe(obj);
|
||||
expect(result.metadata).toBeUndefined();
|
||||
});
|
||||
|
||||
it("workspace image URIs match WorkspaceFileRenderer with loading/error states", () => {
|
||||
// WorkspaceFileRenderer (priority 50) should handle workspace:// URIs
|
||||
// since resolveForRenderer no longer pre-converts them to proxy URLs.
|
||||
const resolved = resolveForRenderer("workspace://abc123#image/png");
|
||||
const renderer = globalRegistry.getRenderer(
|
||||
resolved.value,
|
||||
resolved.metadata,
|
||||
);
|
||||
expect(renderer).toBeDefined();
|
||||
expect(renderer!.name).toBe("WorkspaceFileRenderer");
|
||||
});
|
||||
|
||||
it("workspace video URIs match WorkspaceFileRenderer", () => {
|
||||
const resolved = resolveForRenderer("workspace://vid456#video/mp4");
|
||||
const renderer = globalRegistry.getRenderer(
|
||||
resolved.value,
|
||||
resolved.metadata,
|
||||
);
|
||||
expect(renderer).toBeDefined();
|
||||
expect(renderer!.name).toBe("WorkspaceFileRenderer");
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,96 @@
|
||||
import {
|
||||
getGetV2GetSpecificAgentMockHandler,
|
||||
getGetV2GetSpecificAgentResponseMock,
|
||||
getGetV2ListStoreAgentsMockHandler,
|
||||
getGetV2ListStoreAgentsResponseMock,
|
||||
} from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import { render, screen } from "@/tests/integrations/test-utils";
|
||||
import { MainAgentPage } from "../MainAgentPage";
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
const mockUseSupabase = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
|
||||
useSupabase: mockUseSupabase,
|
||||
}));
|
||||
|
||||
describe("MainAgentPage", () => {
|
||||
beforeEach(() => {
|
||||
mockUseSupabase.mockReturnValue({
|
||||
user: null,
|
||||
});
|
||||
});
|
||||
|
||||
test("renders the marketplace agent details and related sections", async () => {
|
||||
const agentDetails = getGetV2GetSpecificAgentResponseMock({
|
||||
agent_name: "Deterministic Agent",
|
||||
creator: "AutoGPT",
|
||||
creator_avatar: "",
|
||||
sub_heading: "A stable marketplace listing",
|
||||
description: "This agent is used for integration coverage.",
|
||||
categories: ["demo", "test"],
|
||||
versions: ["1", "2"],
|
||||
active_version_id: "store-version-1",
|
||||
store_listing_version_id: "listing-1",
|
||||
agent_image: ["https://example.com/agent.png"],
|
||||
agent_output_demo: "",
|
||||
agent_video: "",
|
||||
});
|
||||
const otherAgents = getGetV2ListStoreAgentsResponseMock({
|
||||
agents: [
|
||||
{
|
||||
...getGetV2ListStoreAgentsResponseMock().agents[0],
|
||||
slug: "other-agent",
|
||||
agent_name: "Other Agent",
|
||||
creator: "AutoGPT",
|
||||
},
|
||||
],
|
||||
});
|
||||
const similarAgents = getGetV2ListStoreAgentsResponseMock({
|
||||
agents: [
|
||||
{
|
||||
...getGetV2ListStoreAgentsResponseMock().agents[0],
|
||||
slug: "similar-agent",
|
||||
agent_name: "Similar Agent",
|
||||
creator: "Another Creator",
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
server.use(
|
||||
getGetV2GetSpecificAgentMockHandler(agentDetails),
|
||||
getGetV2ListStoreAgentsMockHandler(({ request }) => {
|
||||
const url = new URL(request.url);
|
||||
|
||||
if (url.searchParams.get("creator") === "autogpt") {
|
||||
return otherAgents;
|
||||
}
|
||||
|
||||
if (url.searchParams.get("search_query") === "deterministic agent") {
|
||||
return similarAgents;
|
||||
}
|
||||
|
||||
return getGetV2ListStoreAgentsResponseMock({ agents: [] });
|
||||
}),
|
||||
);
|
||||
|
||||
render(
|
||||
<MainAgentPage
|
||||
params={{ creator: "autogpt", slug: "deterministic-agent" }}
|
||||
/>,
|
||||
);
|
||||
|
||||
expect((await screen.findByTestId("agent-title")).textContent).toContain(
|
||||
"Deterministic Agent",
|
||||
);
|
||||
expect(screen.getByTestId("agent-description").textContent).toContain(
|
||||
"This agent is used for integration coverage.",
|
||||
);
|
||||
expect(screen.getByTestId("agent-creator").textContent).toContain(
|
||||
"AutoGPT",
|
||||
);
|
||||
expect(screen.getByText("Other agents by AutoGPT")).toBeDefined();
|
||||
expect(screen.getByText("Similar agents")).toBeDefined();
|
||||
});
|
||||
});
|
||||
@@ -1,15 +1,64 @@
|
||||
import { expect, test } from "vitest";
|
||||
import {
|
||||
getGetV2ListStoreAgentsResponseMock,
|
||||
getGetV2ListStoreCreatorsResponseMock,
|
||||
} from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { render, screen } from "@/tests/integrations/test-utils";
|
||||
import { MainMarkeplacePage } from "../MainMarketplacePage";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import { getDeleteV2DeleteStoreSubmissionMockHandler422 } from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
// Only for CI testing purpose, will remove it in future PR
|
||||
test("MainMarketplacePage", async () => {
|
||||
server.use(getDeleteV2DeleteStoreSubmissionMockHandler422());
|
||||
const mockUseMainMarketplacePage = vi.hoisted(() => vi.fn());
|
||||
|
||||
render(<MainMarkeplacePage />);
|
||||
expect(
|
||||
await screen.findByText("Featured agents", { exact: false }),
|
||||
).toBeDefined();
|
||||
vi.mock("../useMainMarketplacePage", () => ({
|
||||
useMainMarketplacePage: mockUseMainMarketplacePage,
|
||||
}));
|
||||
|
||||
describe("MainMarketplacePage", () => {
|
||||
beforeEach(() => {
|
||||
mockUseMainMarketplacePage.mockReturnValue({
|
||||
featuredAgents: getGetV2ListStoreAgentsResponseMock({
|
||||
agents: [
|
||||
{
|
||||
...getGetV2ListStoreAgentsResponseMock().agents[0],
|
||||
slug: "featured-agent",
|
||||
agent_name: "Featured Agent",
|
||||
creator: "AutoGPT",
|
||||
},
|
||||
],
|
||||
}),
|
||||
topAgents: getGetV2ListStoreAgentsResponseMock({
|
||||
agents: [
|
||||
{
|
||||
...getGetV2ListStoreAgentsResponseMock().agents[0],
|
||||
slug: "top-agent",
|
||||
agent_name: "Top Agent",
|
||||
creator: "AutoGPT",
|
||||
},
|
||||
],
|
||||
}),
|
||||
featuredCreators: getGetV2ListStoreCreatorsResponseMock({
|
||||
creators: [
|
||||
{
|
||||
...getGetV2ListStoreCreatorsResponseMock().creators[0],
|
||||
name: "Creator One",
|
||||
username: "creator-one",
|
||||
},
|
||||
],
|
||||
}),
|
||||
isLoading: false,
|
||||
hasError: false,
|
||||
});
|
||||
});
|
||||
|
||||
test("renders featured agents, all agents, and creators", () => {
|
||||
render(<MainMarkeplacePage />);
|
||||
|
||||
expect(screen.getByText(/Featured agents/i)).toBeDefined();
|
||||
expect(screen.getByText("Featured Agent")).toBeDefined();
|
||||
expect(screen.getByText("All Agents")).toBeDefined();
|
||||
expect(screen.getAllByText("Top Agent").length).toBeGreaterThan(0);
|
||||
expect(screen.getByText("Creator One")).toBeDefined();
|
||||
expect(
|
||||
screen.getByRole("button", { name: "Become a Creator" }),
|
||||
).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
import { render, screen } from "@/tests/integrations/test-utils";
|
||||
import {
|
||||
getGetV2GetCreatorDetailsResponseMock,
|
||||
getGetV2ListStoreAgentsResponseMock,
|
||||
} from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { MainCreatorPage } from "../MainCreatorPage";
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
const mockUseMainCreatorPage = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("../useMainCreatorPage", () => ({
|
||||
useMainCreatorPage: mockUseMainCreatorPage,
|
||||
}));
|
||||
|
||||
describe("MainCreatorPage", () => {
|
||||
beforeEach(() => {
|
||||
const creator = getGetV2GetCreatorDetailsResponseMock({
|
||||
name: "Creator One",
|
||||
username: "creator-one",
|
||||
description: "Creator profile used for integration coverage.",
|
||||
avatar_url: "",
|
||||
top_categories: ["automation", "productivity"],
|
||||
links: ["https://example.com/creator"],
|
||||
});
|
||||
|
||||
const creatorAgents = getGetV2ListStoreAgentsResponseMock({
|
||||
agents: [
|
||||
{
|
||||
...getGetV2ListStoreAgentsResponseMock().agents[0],
|
||||
slug: "creator-agent",
|
||||
agent_name: "Creator Agent",
|
||||
creator: "Creator One",
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
mockUseMainCreatorPage.mockReturnValue({
|
||||
creatorAgents,
|
||||
creator,
|
||||
isLoading: false,
|
||||
hasError: false,
|
||||
});
|
||||
});
|
||||
|
||||
test("renders creator details and their agents", () => {
|
||||
render(<MainCreatorPage params={{ creator: "creator-one" }} />);
|
||||
|
||||
expect(screen.getByTestId("creator-title").textContent).toContain(
|
||||
"Creator One",
|
||||
);
|
||||
expect(screen.getByTestId("creator-description").textContent).toContain(
|
||||
"Creator profile used for integration coverage.",
|
||||
);
|
||||
expect(screen.getByText("Agents by Creator One")).toBeDefined();
|
||||
expect(screen.getAllByText("Creator Agent").length).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,83 @@
|
||||
import type { ReactNode } from "react";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import {
|
||||
getGetV2GetUserProfileMockHandler,
|
||||
getPostV2UpdateUserProfileMockHandler,
|
||||
} from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import UserProfilePage from "../page";
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
const mockUseSupabase = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("@/providers/onboarding/onboarding-provider", () => ({
|
||||
default: ({ children }: { children: ReactNode }) => <>{children}</>,
|
||||
}));
|
||||
|
||||
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
|
||||
useSupabase: mockUseSupabase,
|
||||
}));
|
||||
|
||||
const testUser = {
|
||||
id: "user-1",
|
||||
email: "user@example.com",
|
||||
app_metadata: {},
|
||||
user_metadata: {},
|
||||
aud: "authenticated",
|
||||
created_at: "2026-01-01T00:00:00.000Z",
|
||||
};
|
||||
|
||||
describe("UserProfilePage", () => {
|
||||
beforeEach(() => {
|
||||
mockUseSupabase.mockReturnValue({
|
||||
user: testUser,
|
||||
isLoggedIn: true,
|
||||
isUserLoading: false,
|
||||
supabase: {},
|
||||
});
|
||||
});
|
||||
|
||||
test("renders the existing profile and saves changes", async () => {
|
||||
let profile = {
|
||||
name: "Original Name",
|
||||
username: "original-user",
|
||||
description: "Original bio",
|
||||
links: ["https://example.com/1", "", "", "", ""],
|
||||
avatar_url: "",
|
||||
is_featured: false,
|
||||
};
|
||||
|
||||
server.use(
|
||||
getGetV2GetUserProfileMockHandler(() => profile),
|
||||
getPostV2UpdateUserProfileMockHandler(async ({ request }) => {
|
||||
profile = (await request.json()) as typeof profile;
|
||||
return profile;
|
||||
}),
|
||||
);
|
||||
|
||||
render(<UserProfilePage />);
|
||||
|
||||
const displayName = await screen.findByLabelText("Display name");
|
||||
const handle = screen.getByLabelText("Handle");
|
||||
const bio = screen.getByLabelText("Bio");
|
||||
|
||||
expect((displayName as HTMLInputElement).value).toBe("Original Name");
|
||||
expect((handle as HTMLInputElement).value).toBe("original-user");
|
||||
|
||||
fireEvent.change(displayName, { target: { value: "Updated Name" } });
|
||||
fireEvent.change(handle, { target: { value: "updated-user" } });
|
||||
fireEvent.change(bio, { target: { value: "Updated bio" } });
|
||||
fireEvent.click(screen.getByRole("button", { name: "Save changes" }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(profile.name).toBe("Updated Name");
|
||||
expect(profile.username).toBe("updated-user");
|
||||
expect(profile.description).toBe("Updated bio");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,138 @@
|
||||
import {
|
||||
fireEvent,
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import {
|
||||
getDeleteV1RevokeApiKeyMockHandler,
|
||||
getGetV1ListUserApiKeysMockHandler,
|
||||
getPostV1CreateNewApiKeyMockHandler,
|
||||
} from "@/app/api/__generated__/endpoints/api-keys/api-keys.msw";
|
||||
import { APIKeyPermission } from "@/app/api/__generated__/models/aPIKeyPermission";
|
||||
import { APIKeyStatus } from "@/app/api/__generated__/models/aPIKeyStatus";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import ApiKeysPage from "../page";
|
||||
import { beforeEach, describe, expect, test } from "vitest";
|
||||
|
||||
type ApiKeyRecord = {
|
||||
id: string;
|
||||
name: string;
|
||||
head: string;
|
||||
tail: string;
|
||||
status: APIKeyStatus;
|
||||
};
|
||||
|
||||
function toApiKeyResponse(key: ApiKeyRecord) {
|
||||
return {
|
||||
id: key.id,
|
||||
user_id: "user-1",
|
||||
scopes: [APIKeyPermission.EXECUTE_GRAPH],
|
||||
type: "api_key" as const,
|
||||
created_at: new Date("2026-01-01T00:00:00.000Z"),
|
||||
expires_at: null,
|
||||
last_used_at: null,
|
||||
revoked_at: null,
|
||||
name: key.name,
|
||||
head: key.head,
|
||||
tail: key.tail,
|
||||
status: key.status,
|
||||
description: null,
|
||||
};
|
||||
}
|
||||
|
||||
describe("ApiKeysPage", () => {
|
||||
let apiKeys: ApiKeyRecord[];
|
||||
let revokedKeyId: string;
|
||||
|
||||
beforeEach(() => {
|
||||
apiKeys = [];
|
||||
revokedKeyId = "";
|
||||
|
||||
server.use(
|
||||
getGetV1ListUserApiKeysMockHandler(() =>
|
||||
apiKeys.map((key) => toApiKeyResponse(key)),
|
||||
),
|
||||
getPostV1CreateNewApiKeyMockHandler(async ({ request }) => {
|
||||
const body = (await request.json()) as {
|
||||
name: string;
|
||||
description?: string;
|
||||
permissions?: APIKeyPermission[];
|
||||
};
|
||||
|
||||
const createdKey: ApiKeyRecord = {
|
||||
id: `key-${apiKeys.length + 1}`,
|
||||
name: body.name,
|
||||
head: "head",
|
||||
tail: "tail",
|
||||
status: APIKeyStatus.ACTIVE,
|
||||
};
|
||||
|
||||
apiKeys = [...apiKeys, createdKey];
|
||||
|
||||
return {
|
||||
api_key: toApiKeyResponse(createdKey),
|
||||
plain_text_key: "plain-text-key",
|
||||
};
|
||||
}),
|
||||
getDeleteV1RevokeApiKeyMockHandler(({ params }) => {
|
||||
const keyId = String(params.keyId);
|
||||
const removedKey = apiKeys.find((key) => key.id === keyId);
|
||||
|
||||
revokedKeyId = keyId;
|
||||
apiKeys = apiKeys.filter((key) => key.id !== keyId);
|
||||
|
||||
return toApiKeyResponse(
|
||||
removedKey ?? {
|
||||
id: keyId,
|
||||
name: "Unknown key",
|
||||
head: "head",
|
||||
tail: "tail",
|
||||
status: APIKeyStatus.REVOKED,
|
||||
},
|
||||
);
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test("creates a new API key", async () => {
|
||||
render(<ApiKeysPage />);
|
||||
|
||||
fireEvent.click(await screen.findByText("Create Key"));
|
||||
fireEvent.change(screen.getByLabelText("Name"), {
|
||||
target: { value: "CLI Key" },
|
||||
});
|
||||
fireEvent.click(screen.getByText("Create"));
|
||||
|
||||
expect(
|
||||
await screen.findByText("AutoGPT Platform API Key Created"),
|
||||
).toBeDefined();
|
||||
|
||||
await waitFor(() => {
|
||||
expect(apiKeys[0]?.name).toBe("CLI Key");
|
||||
});
|
||||
});
|
||||
|
||||
test("revokes an existing API key", async () => {
|
||||
apiKeys = [
|
||||
{
|
||||
id: "key-1",
|
||||
name: "Existing Key",
|
||||
head: "head",
|
||||
tail: "tail",
|
||||
status: APIKeyStatus.ACTIVE,
|
||||
},
|
||||
];
|
||||
|
||||
render(<ApiKeysPage />);
|
||||
|
||||
expect(await screen.findByText("Existing Key")).toBeDefined();
|
||||
|
||||
fireEvent.pointerDown(screen.getByTestId("api-key-actions"));
|
||||
fireEvent.click(await screen.findByRole("menuitem", { name: "Revoke" }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(revokedKeyId).toBe("key-1");
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,76 @@
|
||||
import { render, screen, fireEvent } from "@testing-library/react";
|
||||
import { getGetV2ListMySubmissionsResponseMock } from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { SubmissionStatus } from "@/app/api/__generated__/models/submissionStatus";
|
||||
import { AgentTableRow } from "../AgentTableRow";
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
function makeSubmission(status: SubmissionStatus) {
|
||||
const submission = getGetV2ListMySubmissionsResponseMock().submissions[0];
|
||||
|
||||
return {
|
||||
...submission,
|
||||
graph_id: "graph-1",
|
||||
graph_version: 7,
|
||||
listing_version_id: `listing-${status.toLowerCase()}`,
|
||||
name: `Agent ${status}`,
|
||||
description: `Description ${status}`,
|
||||
status,
|
||||
image_urls: [],
|
||||
submitted_at: new Date("2026-01-01T00:00:00.000Z"),
|
||||
};
|
||||
}
|
||||
|
||||
describe("AgentTableRow", () => {
|
||||
const onViewSubmission = vi.fn();
|
||||
const onDeleteSubmission = vi.fn();
|
||||
const onEditSubmission = vi.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
onViewSubmission.mockReset();
|
||||
onDeleteSubmission.mockReset();
|
||||
onEditSubmission.mockReset();
|
||||
});
|
||||
|
||||
test("shows edit and delete actions for pending submissions", async () => {
|
||||
render(
|
||||
<AgentTableRow
|
||||
storeAgentSubmission={makeSubmission(SubmissionStatus.PENDING)}
|
||||
onViewSubmission={onViewSubmission}
|
||||
onDeleteSubmission={onDeleteSubmission}
|
||||
onEditSubmission={onEditSubmission}
|
||||
/>,
|
||||
);
|
||||
|
||||
fireEvent.pointerDown(screen.getByTestId("agent-table-row-actions"));
|
||||
|
||||
fireEvent.click(await screen.findByText("Edit"));
|
||||
expect(onEditSubmission).toHaveBeenCalledTimes(1);
|
||||
|
||||
fireEvent.pointerDown(screen.getByTestId("agent-table-row-actions"));
|
||||
fireEvent.click(await screen.findByText("Delete"));
|
||||
expect(onDeleteSubmission).toHaveBeenCalledWith("listing-pending");
|
||||
expect(onViewSubmission).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("shows view only for non-pending submissions", async () => {
|
||||
const approvedSubmission = makeSubmission(SubmissionStatus.APPROVED);
|
||||
|
||||
render(
|
||||
<AgentTableRow
|
||||
storeAgentSubmission={approvedSubmission}
|
||||
onViewSubmission={onViewSubmission}
|
||||
onDeleteSubmission={onDeleteSubmission}
|
||||
onEditSubmission={onEditSubmission}
|
||||
/>,
|
||||
);
|
||||
|
||||
fireEvent.pointerDown(screen.getByTestId("agent-table-row-actions"));
|
||||
|
||||
const viewAction = await screen.findByText("View");
|
||||
fireEvent.click(viewAction);
|
||||
|
||||
expect(onViewSubmission).toHaveBeenCalledWith(approvedSubmission);
|
||||
expect(screen.queryByText("Edit")).toBeNull();
|
||||
expect(screen.queryByText("Delete")).toBeNull();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,147 @@
|
||||
import type { ReactNode } from "react";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import {
|
||||
getGetV1GetNotificationPreferencesMockHandler,
|
||||
getGetV1GetUserTimezoneMockHandler,
|
||||
getPostV1UpdateNotificationPreferencesMockHandler,
|
||||
getPostV1UpdateUserEmailMockHandler,
|
||||
} from "@/app/api/__generated__/endpoints/auth/auth.msw";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import SettingsPage from "../page";
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
const mockUseSupabase = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("@/providers/onboarding/onboarding-provider", () => ({
|
||||
default: ({ children }: { children: ReactNode }) => <>{children}</>,
|
||||
}));
|
||||
|
||||
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
|
||||
useSupabase: mockUseSupabase,
|
||||
}));
|
||||
|
||||
const testUser = {
|
||||
id: "user-1",
|
||||
email: "user@example.com",
|
||||
app_metadata: {},
|
||||
user_metadata: {},
|
||||
aud: "authenticated",
|
||||
created_at: "2026-01-01T00:00:00.000Z",
|
||||
};
|
||||
|
||||
describe("SettingsPage", () => {
|
||||
beforeEach(() => {
|
||||
mockUseSupabase.mockReturnValue({
|
||||
user: testUser,
|
||||
isLoggedIn: true,
|
||||
isUserLoading: false,
|
||||
supabase: {},
|
||||
});
|
||||
});
|
||||
|
||||
test("renders the account actions", async () => {
|
||||
server.use(
|
||||
getGetV1GetNotificationPreferencesMockHandler({
|
||||
user_id: "user-1",
|
||||
email: "user@example.com",
|
||||
preferences: {
|
||||
AGENT_RUN: true,
|
||||
ZERO_BALANCE: false,
|
||||
LOW_BALANCE: false,
|
||||
BLOCK_EXECUTION_FAILED: true,
|
||||
CONTINUOUS_AGENT_ERROR: false,
|
||||
DAILY_SUMMARY: false,
|
||||
WEEKLY_SUMMARY: true,
|
||||
MONTHLY_SUMMARY: false,
|
||||
AGENT_APPROVED: true,
|
||||
AGENT_REJECTED: true,
|
||||
},
|
||||
daily_limit: 0,
|
||||
emails_sent_today: 0,
|
||||
last_reset_date: new Date("2026-01-01T00:00:00.000Z"),
|
||||
}),
|
||||
getGetV1GetUserTimezoneMockHandler({ timezone: "Asia/Kolkata" }),
|
||||
getPostV1UpdateUserEmailMockHandler({}),
|
||||
getPostV1UpdateNotificationPreferencesMockHandler({
|
||||
user_id: "user-1",
|
||||
email: "user@example.com",
|
||||
preferences: {},
|
||||
daily_limit: 0,
|
||||
emails_sent_today: 0,
|
||||
last_reset_date: new Date("2026-01-01T00:00:00.000Z"),
|
||||
}),
|
||||
);
|
||||
|
||||
render(<SettingsPage />);
|
||||
|
||||
const emailInput = await screen.findByLabelText("Email");
|
||||
expect((emailInput as HTMLInputElement).value).toBe("user@example.com");
|
||||
expect(
|
||||
screen.getByRole("link", { name: "Reset password" }).getAttribute("href"),
|
||||
).toBe("/reset-password");
|
||||
});
|
||||
|
||||
test("saves notification preference changes", async () => {
|
||||
let submittedPreferences:
|
||||
| {
|
||||
email: string;
|
||||
preferences: Record<string, boolean>;
|
||||
}
|
||||
| undefined;
|
||||
|
||||
server.use(
|
||||
getGetV1GetNotificationPreferencesMockHandler({
|
||||
user_id: "user-1",
|
||||
email: "user@example.com",
|
||||
preferences: {
|
||||
AGENT_RUN: false,
|
||||
ZERO_BALANCE: false,
|
||||
LOW_BALANCE: false,
|
||||
BLOCK_EXECUTION_FAILED: false,
|
||||
CONTINUOUS_AGENT_ERROR: false,
|
||||
DAILY_SUMMARY: false,
|
||||
WEEKLY_SUMMARY: false,
|
||||
MONTHLY_SUMMARY: false,
|
||||
AGENT_APPROVED: false,
|
||||
AGENT_REJECTED: false,
|
||||
},
|
||||
daily_limit: 0,
|
||||
emails_sent_today: 0,
|
||||
last_reset_date: new Date("2026-01-01T00:00:00.000Z"),
|
||||
}),
|
||||
getGetV1GetUserTimezoneMockHandler({ timezone: "Asia/Kolkata" }),
|
||||
getPostV1UpdateUserEmailMockHandler({}),
|
||||
getPostV1UpdateNotificationPreferencesMockHandler(async ({ request }) => {
|
||||
submittedPreferences = (await request.json()) as {
|
||||
email: string;
|
||||
preferences: Record<string, boolean>;
|
||||
};
|
||||
|
||||
return {
|
||||
user_id: "user-1",
|
||||
email: submittedPreferences.email,
|
||||
preferences: submittedPreferences.preferences,
|
||||
daily_limit: 0,
|
||||
emails_sent_today: 0,
|
||||
last_reset_date: new Date("2026-01-01T00:00:00.000Z"),
|
||||
};
|
||||
}),
|
||||
);
|
||||
|
||||
render(<SettingsPage />);
|
||||
|
||||
fireEvent.click(
|
||||
await screen.findByRole("switch", { name: "Agent Run Notifications" }),
|
||||
);
|
||||
fireEvent.click(screen.getByRole("button", { name: "Save preferences" }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(submittedPreferences?.preferences.AGENT_RUN).toBe(true);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,97 @@
|
||||
import {
|
||||
fireEvent,
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import type { ReactNode } from "react";
|
||||
import type { User } from "@supabase/supabase-js";
|
||||
import { afterEach, beforeEach, describe, expect, test, vi } from "vitest";
|
||||
import { EmailForm } from "../EmailForm";
|
||||
|
||||
const mockToast = vi.hoisted(() => vi.fn());
|
||||
const mockMutateAsync = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
useToast: () => ({ toast: mockToast }),
|
||||
}));
|
||||
|
||||
vi.mock("@/app/api/__generated__/endpoints/auth/auth", () => ({
|
||||
usePostV1UpdateUserEmail: () => ({
|
||||
mutateAsync: mockMutateAsync,
|
||||
isPending: false,
|
||||
}),
|
||||
}));
|
||||
|
||||
vi.mock("@/providers/onboarding/onboarding-provider", () => ({
|
||||
default: ({ children }: { children: ReactNode }) => <>{children}</>,
|
||||
}));
|
||||
|
||||
const testUser = {
|
||||
id: "user-1",
|
||||
email: "user@example.com",
|
||||
app_metadata: {},
|
||||
user_metadata: {},
|
||||
aud: "authenticated",
|
||||
created_at: "2026-01-01T00:00:00.000Z",
|
||||
} as User;
|
||||
|
||||
describe("EmailForm", () => {
|
||||
beforeEach(() => {
|
||||
mockToast.mockReset();
|
||||
mockMutateAsync.mockReset();
|
||||
mockMutateAsync.mockResolvedValue({});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
test("submits a changed email to both update endpoints", async () => {
|
||||
const fetchMock = vi.fn().mockResolvedValue({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
});
|
||||
|
||||
vi.stubGlobal("fetch", fetchMock);
|
||||
|
||||
render(<EmailForm user={testUser} />);
|
||||
|
||||
fireEvent.change(screen.getByLabelText("Email"), {
|
||||
target: { value: "updated@example.com" },
|
||||
});
|
||||
fireEvent.click(screen.getByRole("button", { name: "Update email" }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchMock).toHaveBeenCalledWith("/api/auth/user", {
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ email: "updated@example.com" }),
|
||||
});
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(mockMutateAsync).toHaveBeenCalledWith({
|
||||
data: "updated@example.com",
|
||||
});
|
||||
});
|
||||
expect(mockToast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
title: "Successfully updated email",
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
test("keeps submit disabled when the email has not changed", () => {
|
||||
render(<EmailForm user={testUser} />);
|
||||
|
||||
expect(
|
||||
(
|
||||
screen.getByRole("button", {
|
||||
name: "Update email",
|
||||
}) as HTMLButtonElement
|
||||
).disabled,
|
||||
).toBe(true);
|
||||
});
|
||||
});
|
||||
@@ -55,6 +55,7 @@ export function NotificationForm({ preferences, user }: NotificationFormProps) {
|
||||
</div>
|
||||
<FormControl>
|
||||
<Switch
|
||||
aria-label="Agent Run Notifications"
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
/>
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
import type { ReactNode } from "react";
|
||||
import {
|
||||
render,
|
||||
screen,
|
||||
fireEvent,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import SignupPage from "../page";
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
const mockUseSupabase = vi.hoisted(() => vi.fn());
|
||||
const mockSignupAction = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("@/providers/onboarding/onboarding-provider", () => ({
|
||||
default: ({ children }: { children: ReactNode }) => <>{children}</>,
|
||||
}));
|
||||
|
||||
vi.mock("@/lib/supabase/hooks/useSupabase", () => ({
|
||||
useSupabase: mockUseSupabase,
|
||||
}));
|
||||
|
||||
vi.mock("../actions", () => ({
|
||||
signup: mockSignupAction,
|
||||
}));
|
||||
|
||||
describe("SignupPage", () => {
|
||||
beforeEach(() => {
|
||||
mockUseSupabase.mockReturnValue({
|
||||
supabase: {},
|
||||
user: null,
|
||||
isUserLoading: false,
|
||||
isLoggedIn: false,
|
||||
});
|
||||
mockSignupAction.mockReset();
|
||||
});
|
||||
|
||||
test("shows existing user feedback from signup action", async () => {
|
||||
mockSignupAction.mockResolvedValue({
|
||||
success: false,
|
||||
error: "user_already_exists",
|
||||
});
|
||||
|
||||
render(<SignupPage />);
|
||||
|
||||
fireEvent.change(screen.getByLabelText("Email"), {
|
||||
target: { value: "existing@example.com" },
|
||||
});
|
||||
fireEvent.change(screen.getByLabelText("Password", { selector: "input" }), {
|
||||
target: { value: "validpassword123" },
|
||||
});
|
||||
fireEvent.change(
|
||||
screen.getByLabelText("Confirm Password", { selector: "input" }),
|
||||
{
|
||||
target: { value: "validpassword123" },
|
||||
},
|
||||
);
|
||||
fireEvent.click(screen.getByRole("checkbox"));
|
||||
fireEvent.click(screen.getByRole("button", { name: "Sign up" }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockSignupAction).toHaveBeenCalledWith(
|
||||
"existing@example.com",
|
||||
"validpassword123",
|
||||
"validpassword123",
|
||||
true,
|
||||
);
|
||||
});
|
||||
|
||||
expect(
|
||||
await screen.findByText("User with this email already exists"),
|
||||
).toBeDefined();
|
||||
});
|
||||
});
|
||||
@@ -1,25 +0,0 @@
|
||||
/**
|
||||
* Generated by orval v7.13.0 🍺
|
||||
* Do not edit manually.
|
||||
* AutoGPT Agent Server
|
||||
* This server is used to execute agents that are created by the AutoGPT system.
|
||||
* OpenAPI spec version: 0.1
|
||||
*/
|
||||
import type { ResponseType } from "./responseType";
|
||||
import type { BlockOutputResponseSessionId } from "./blockOutputResponseSessionId";
|
||||
import type { BlockOutputResponseOutputs } from "./blockOutputResponseOutputs";
|
||||
import type { BlockOutputResponseIsDryRun } from "./blockOutputResponseIsDryRun";
|
||||
|
||||
/**
|
||||
* Response for run_block tool.
|
||||
*/
|
||||
export interface BlockOutputResponse {
|
||||
type?: ResponseType;
|
||||
message: string;
|
||||
session_id?: BlockOutputResponseSessionId;
|
||||
block_id: string;
|
||||
block_name: string;
|
||||
outputs: BlockOutputResponseOutputs;
|
||||
success?: boolean;
|
||||
is_dry_run?: BlockOutputResponseIsDryRun;
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
/**
|
||||
* Generated by orval v7.13.0 🍺
|
||||
* Do not edit manually.
|
||||
* AutoGPT Agent Server
|
||||
* This server is used to execute agents that are created by the AutoGPT system.
|
||||
* OpenAPI spec version: 0.1
|
||||
*/
|
||||
import type { GraphExecutionMetaInputs } from "./graphExecutionMetaInputs";
|
||||
import type { GraphExecutionMetaCredentialInputs } from "./graphExecutionMetaCredentialInputs";
|
||||
import type { GraphExecutionMetaNodesInputMasks } from "./graphExecutionMetaNodesInputMasks";
|
||||
import type { GraphExecutionMetaPresetId } from "./graphExecutionMetaPresetId";
|
||||
import type { AgentExecutionStatus } from "./agentExecutionStatus";
|
||||
import type { GraphExecutionMetaStartedAt } from "./graphExecutionMetaStartedAt";
|
||||
import type { GraphExecutionMetaEndedAt } from "./graphExecutionMetaEndedAt";
|
||||
import type { GraphExecutionMetaShareToken } from "./graphExecutionMetaShareToken";
|
||||
import type { GraphExecutionMetaStats } from "./graphExecutionMetaStats";
|
||||
|
||||
export interface GraphExecutionMeta {
|
||||
id: string;
|
||||
user_id: string;
|
||||
graph_id: string;
|
||||
graph_version: number;
|
||||
inputs: GraphExecutionMetaInputs;
|
||||
credential_inputs: GraphExecutionMetaCredentialInputs;
|
||||
nodes_input_masks: GraphExecutionMetaNodesInputMasks;
|
||||
preset_id: GraphExecutionMetaPresetId;
|
||||
status: AgentExecutionStatus;
|
||||
/** When execution started running. Null if not yet started (QUEUED). */
|
||||
started_at?: GraphExecutionMetaStartedAt;
|
||||
/** When execution finished. Null if not yet completed (QUEUED, RUNNING, INCOMPLETE, REVIEW). */
|
||||
ended_at?: GraphExecutionMetaEndedAt;
|
||||
is_shared?: boolean;
|
||||
share_token?: GraphExecutionMetaShareToken;
|
||||
is_dry_run?: boolean;
|
||||
stats: GraphExecutionMetaStats;
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
/**
|
||||
* Generated by orval v7.13.0 🍺
|
||||
* Do not edit manually.
|
||||
* AutoGPT Agent Server
|
||||
* This server is used to execute agents that are created by the AutoGPT system.
|
||||
* OpenAPI spec version: 0.1
|
||||
*/
|
||||
import type { SuggestedTheme } from "./suggestedTheme";
|
||||
|
||||
/**
|
||||
* Response model for user-specific suggested prompts grouped by theme.
|
||||
*/
|
||||
export interface SuggestedPromptsResponse {
|
||||
themes: SuggestedTheme[];
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
/**
|
||||
* Generated by orval v7.13.0 🍺
|
||||
* Do not edit manually.
|
||||
* AutoGPT Agent Server
|
||||
* This server is used to execute agents that are created by the AutoGPT system.
|
||||
* OpenAPI spec version: 0.1
|
||||
*/
|
||||
|
||||
/**
|
||||
* A themed group of suggested prompts.
|
||||
*/
|
||||
export interface SuggestedTheme {
|
||||
name: string;
|
||||
prompts: string[];
|
||||
}
|
||||
@@ -9123,6 +9123,15 @@
|
||||
],
|
||||
"title": "ContentType"
|
||||
},
|
||||
"CostBucket": {
|
||||
"properties": {
|
||||
"bucket": { "type": "string", "title": "Bucket" },
|
||||
"count": { "type": "integer", "title": "Count" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["bucket", "count"],
|
||||
"title": "CostBucket"
|
||||
},
|
||||
"CostLogRow": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -12141,7 +12150,58 @@
|
||||
"title": "Total Cost Microdollars"
|
||||
},
|
||||
"total_requests": { "type": "integer", "title": "Total Requests" },
|
||||
"total_users": { "type": "integer", "title": "Total Users" }
|
||||
"total_users": { "type": "integer", "title": "Total Users" },
|
||||
"total_input_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Input Tokens",
|
||||
"default": 0
|
||||
},
|
||||
"total_output_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Total Output Tokens",
|
||||
"default": 0
|
||||
},
|
||||
"avg_input_tokens_per_request": {
|
||||
"type": "number",
|
||||
"title": "Avg Input Tokens Per Request",
|
||||
"default": 0.0
|
||||
},
|
||||
"avg_output_tokens_per_request": {
|
||||
"type": "number",
|
||||
"title": "Avg Output Tokens Per Request",
|
||||
"default": 0.0
|
||||
},
|
||||
"avg_cost_microdollars_per_request": {
|
||||
"type": "number",
|
||||
"title": "Avg Cost Microdollars Per Request",
|
||||
"default": 0.0
|
||||
},
|
||||
"cost_p50_microdollars": {
|
||||
"type": "number",
|
||||
"title": "Cost P50 Microdollars",
|
||||
"default": 0.0
|
||||
},
|
||||
"cost_p75_microdollars": {
|
||||
"type": "number",
|
||||
"title": "Cost P75 Microdollars",
|
||||
"default": 0.0
|
||||
},
|
||||
"cost_p95_microdollars": {
|
||||
"type": "number",
|
||||
"title": "Cost P95 Microdollars",
|
||||
"default": 0.0
|
||||
},
|
||||
"cost_p99_microdollars": {
|
||||
"type": "number",
|
||||
"title": "Cost P99 Microdollars",
|
||||
"default": 0.0
|
||||
},
|
||||
"cost_buckets": {
|
||||
"items": { "$ref": "#/components/schemas/CostBucket" },
|
||||
"type": "array",
|
||||
"title": "Cost Buckets",
|
||||
"default": []
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
@@ -15585,7 +15645,12 @@
|
||||
"title": "Total Cache Creation Tokens",
|
||||
"default": 0
|
||||
},
|
||||
"request_count": { "type": "integer", "title": "Request Count" }
|
||||
"request_count": { "type": "integer", "title": "Request Count" },
|
||||
"cost_bearing_request_count": {
|
||||
"type": "integer",
|
||||
"title": "Cost Bearing Request Count",
|
||||
"default": 0
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
import { describe, expect, it, vi, beforeEach, afterEach } from "vitest";
|
||||
import {
|
||||
isWorkspaceDownloadRequest,
|
||||
isRedirectStatus,
|
||||
isTransientWorkspaceDownloadStatus,
|
||||
getWorkspaceDownloadErrorMessage,
|
||||
fetchWorkspaceDownloadOnce,
|
||||
fetchWorkspaceDownloadWithRetry,
|
||||
} from "./route.helpers";
|
||||
|
||||
describe("isWorkspaceDownloadRequest", () => {
|
||||
it("matches api/workspace/files/{id}/download pattern", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"abc-123",
|
||||
"download",
|
||||
]),
|
||||
).toBe(true);
|
||||
});
|
||||
|
||||
it("rejects paths with wrong segment count", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest(["api", "workspace", "files", "download"]),
|
||||
).toBe(false);
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"id",
|
||||
"download",
|
||||
"extra",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects paths with wrong prefix", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"v1",
|
||||
"workspace",
|
||||
"files",
|
||||
"id",
|
||||
"download",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
|
||||
it("rejects paths not ending with download", () => {
|
||||
expect(
|
||||
isWorkspaceDownloadRequest([
|
||||
"api",
|
||||
"workspace",
|
||||
"files",
|
||||
"id",
|
||||
"metadata",
|
||||
]),
|
||||
).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isRedirectStatus", () => {
|
||||
it.each([301, 302, 303, 307, 308])("returns true for %d", (status) => {
|
||||
expect(isRedirectStatus(status)).toBe(true);
|
||||
});
|
||||
|
||||
it.each([200, 304, 400, 404, 500])("returns false for %d", (status) => {
|
||||
expect(isRedirectStatus(status)).toBe(false);
|
||||
});
|
||||
});
|
||||
|
||||
describe("isTransientWorkspaceDownloadStatus", () => {
|
||||
it.each([408, 429, 500, 502, 503, 504])(
|
||||
"returns true for transient %d",
|
||||
(status) => {
|
||||
expect(isTransientWorkspaceDownloadStatus(status)).toBe(true);
|
||||
},
|
||||
);
|
||||
|
||||
it.each([400, 401, 403, 404, 405])(
|
||||
"returns false for non-transient %d",
|
||||
(status) => {
|
||||
expect(isTransientWorkspaceDownloadStatus(status)).toBe(false);
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
describe("getWorkspaceDownloadErrorMessage", () => {
|
||||
it("extracts detail string from object", () => {
|
||||
expect(getWorkspaceDownloadErrorMessage({ detail: "Not found" })).toBe(
|
||||
"Not found",
|
||||
);
|
||||
});
|
||||
|
||||
it("extracts error string from object", () => {
|
||||
expect(getWorkspaceDownloadErrorMessage({ error: "Server error" })).toBe(
|
||||
"Server error",
|
||||
);
|
||||
});
|
||||
|
||||
it("extracts nested detail.message", () => {
|
||||
expect(
|
||||
getWorkspaceDownloadErrorMessage({
|
||||
detail: { message: "Nested error" },
|
||||
}),
|
||||
).toBe("Nested error");
|
||||
});
|
||||
|
||||
it("returns trimmed string body", () => {
|
||||
expect(getWorkspaceDownloadErrorMessage(" error text ")).toBe(
|
||||
"error text",
|
||||
);
|
||||
});
|
||||
|
||||
it("returns null for empty string", () => {
|
||||
expect(getWorkspaceDownloadErrorMessage("")).toBeNull();
|
||||
});
|
||||
|
||||
it("returns null for whitespace-only string", () => {
|
||||
expect(getWorkspaceDownloadErrorMessage(" ")).toBeNull();
|
||||
});
|
||||
|
||||
it("returns null for null/undefined", () => {
|
||||
expect(getWorkspaceDownloadErrorMessage(null)).toBeNull();
|
||||
expect(getWorkspaceDownloadErrorMessage(undefined)).toBeNull();
|
||||
});
|
||||
|
||||
it("returns null for object with empty detail", () => {
|
||||
expect(getWorkspaceDownloadErrorMessage({ detail: "" })).toBeNull();
|
||||
});
|
||||
|
||||
it("returns null for object with no recognized keys", () => {
|
||||
expect(getWorkspaceDownloadErrorMessage({ foo: "bar" })).toBeNull();
|
||||
});
|
||||
|
||||
it("prefers detail over error", () => {
|
||||
expect(
|
||||
getWorkspaceDownloadErrorMessage({
|
||||
detail: "detail msg",
|
||||
error: "error msg",
|
||||
}),
|
||||
).toBe("detail msg");
|
||||
});
|
||||
});
|
||||
|
||||
describe("fetchWorkspaceDownloadOnce", () => {
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal("fetch", vi.fn());
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("returns response directly for non-redirect status", async () => {
|
||||
const mockResponse = { ok: true, status: 200, headers: new Headers() };
|
||||
vi.mocked(fetch).mockResolvedValue(mockResponse as unknown as Response);
|
||||
|
||||
const result = await fetchWorkspaceDownloadOnce("https://backend/file", {});
|
||||
expect(result).toBe(mockResponse);
|
||||
expect(fetch).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("follows redirect when Location header is present", async () => {
|
||||
const redirectResponse = {
|
||||
ok: false,
|
||||
status: 302,
|
||||
headers: new Headers({ Location: "https://storage.example.com/file" }),
|
||||
};
|
||||
const finalResponse = { ok: true, status: 200, headers: new Headers() };
|
||||
vi.mocked(fetch)
|
||||
.mockResolvedValueOnce(redirectResponse as unknown as Response)
|
||||
.mockResolvedValueOnce(finalResponse as unknown as Response);
|
||||
|
||||
const result = await fetchWorkspaceDownloadOnce("https://backend/file", {
|
||||
Authorization: "Bearer token",
|
||||
});
|
||||
expect(result).toBe(finalResponse);
|
||||
expect(fetch).toHaveBeenCalledTimes(2);
|
||||
expect(fetch).toHaveBeenNthCalledWith(
|
||||
2,
|
||||
"https://storage.example.com/file",
|
||||
{ method: "GET", redirect: "follow" },
|
||||
);
|
||||
});
|
||||
|
||||
it("returns redirect response when Location header is missing", async () => {
|
||||
const redirectResponse = {
|
||||
ok: false,
|
||||
status: 307,
|
||||
headers: new Headers(),
|
||||
};
|
||||
vi.mocked(fetch).mockResolvedValue(redirectResponse as unknown as Response);
|
||||
|
||||
const result = await fetchWorkspaceDownloadOnce("https://backend/file", {});
|
||||
expect(result).toBe(redirectResponse);
|
||||
expect(fetch).toHaveBeenCalledOnce();
|
||||
});
|
||||
});
|
||||
|
||||
describe("fetchWorkspaceDownloadWithRetry", () => {
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal("fetch", vi.fn());
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
it("returns immediately on success", async () => {
|
||||
const okResponse = { ok: true, status: 200, headers: new Headers() };
|
||||
vi.mocked(fetch).mockResolvedValue(okResponse as unknown as Response);
|
||||
|
||||
const result = await fetchWorkspaceDownloadWithRetry(
|
||||
"https://backend/file",
|
||||
{},
|
||||
2,
|
||||
0,
|
||||
);
|
||||
expect(result).toBe(okResponse);
|
||||
expect(fetch).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("returns immediately on non-transient error without retrying", async () => {
|
||||
const notFound = { ok: false, status: 404, headers: new Headers() };
|
||||
vi.mocked(fetch).mockResolvedValue(notFound as unknown as Response);
|
||||
|
||||
const result = await fetchWorkspaceDownloadWithRetry(
|
||||
"https://backend/file",
|
||||
{},
|
||||
2,
|
||||
0,
|
||||
);
|
||||
expect(result.status).toBe(404);
|
||||
expect(fetch).toHaveBeenCalledOnce();
|
||||
});
|
||||
|
||||
it("retries on transient 502 and succeeds", async () => {
|
||||
const bad = { ok: false, status: 502, headers: new Headers() };
|
||||
const ok = { ok: true, status: 200, headers: new Headers() };
|
||||
vi.mocked(fetch)
|
||||
.mockResolvedValueOnce(bad as unknown as Response)
|
||||
.mockResolvedValueOnce(ok as unknown as Response);
|
||||
|
||||
const result = await fetchWorkspaceDownloadWithRetry(
|
||||
"https://backend/file",
|
||||
{},
|
||||
2,
|
||||
0,
|
||||
);
|
||||
expect(result).toBe(ok);
|
||||
expect(fetch).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
|
||||
it("returns last transient response after exhausting retries", async () => {
|
||||
const bad = { ok: false, status: 503, headers: new Headers() };
|
||||
vi.mocked(fetch).mockResolvedValue(bad as unknown as Response);
|
||||
|
||||
const result = await fetchWorkspaceDownloadWithRetry(
|
||||
"https://backend/file",
|
||||
{},
|
||||
2,
|
||||
0,
|
||||
);
|
||||
expect(result.status).toBe(503);
|
||||
expect(fetch).toHaveBeenCalledTimes(3);
|
||||
});
|
||||
|
||||
it("retries on network error and throws after exhausting retries", async () => {
|
||||
vi.mocked(fetch).mockRejectedValue(new Error("Connection reset"));
|
||||
|
||||
await expect(
|
||||
fetchWorkspaceDownloadWithRetry("https://backend/file", {}, 1, 0),
|
||||
).rejects.toThrow("Connection reset");
|
||||
expect(fetch).toHaveBeenCalledTimes(2);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,108 @@
|
||||
export function isWorkspaceDownloadRequest(path: string[]): boolean {
|
||||
return (
|
||||
path.length == 5 &&
|
||||
path[0] === "api" &&
|
||||
path[1] === "workspace" &&
|
||||
path[2] === "files" &&
|
||||
path[path.length - 1] === "download"
|
||||
);
|
||||
}
|
||||
|
||||
export function isRedirectStatus(status: number): boolean {
|
||||
return [301, 302, 303, 307, 308].includes(status);
|
||||
}
|
||||
|
||||
export function isTransientWorkspaceDownloadStatus(status: number): boolean {
|
||||
return status === 408 || status === 429 || status >= 500;
|
||||
}
|
||||
|
||||
export function sleep(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
export async function fetchWorkspaceDownloadOnce(
|
||||
backendUrl: string,
|
||||
headers: Record<string, string>,
|
||||
): Promise<Response> {
|
||||
const backendResponse = await fetch(backendUrl, {
|
||||
method: "GET",
|
||||
headers,
|
||||
redirect: "manual",
|
||||
});
|
||||
|
||||
if (!isRedirectStatus(backendResponse.status)) {
|
||||
return backendResponse;
|
||||
}
|
||||
|
||||
const location = backendResponse.headers.get("Location");
|
||||
if (!location) return backendResponse;
|
||||
|
||||
return await fetch(location, {
|
||||
method: "GET",
|
||||
redirect: "follow",
|
||||
});
|
||||
}
|
||||
|
||||
export async function fetchWorkspaceDownloadWithRetry(
|
||||
backendUrl: string,
|
||||
headers: Record<string, string>,
|
||||
maxRetries: number,
|
||||
retryDelayMs: number,
|
||||
): Promise<Response> {
|
||||
for (let attempt = 0; attempt <= maxRetries; attempt++) {
|
||||
try {
|
||||
const response = await fetchWorkspaceDownloadOnce(backendUrl, headers);
|
||||
if (
|
||||
response.ok ||
|
||||
!isTransientWorkspaceDownloadStatus(response.status) ||
|
||||
attempt === maxRetries
|
||||
) {
|
||||
return response;
|
||||
}
|
||||
} catch (error) {
|
||||
if (attempt === maxRetries) throw error;
|
||||
}
|
||||
|
||||
await sleep(retryDelayMs);
|
||||
}
|
||||
|
||||
throw new Error("Workspace download failed after retries");
|
||||
}
|
||||
|
||||
export function getWorkspaceDownloadErrorMessage(body: unknown): string | null {
|
||||
if (typeof body === "string") {
|
||||
const trimmed = body.trim();
|
||||
return trimmed || null;
|
||||
}
|
||||
|
||||
if (!body || typeof body !== "object") return null;
|
||||
|
||||
if (
|
||||
"detail" in body &&
|
||||
typeof body.detail === "string" &&
|
||||
body.detail.trim().length > 0
|
||||
) {
|
||||
return body.detail.trim();
|
||||
}
|
||||
|
||||
if (
|
||||
"error" in body &&
|
||||
typeof body.error === "string" &&
|
||||
body.error.trim().length > 0
|
||||
) {
|
||||
return body.error.trim();
|
||||
}
|
||||
|
||||
if (
|
||||
"detail" in body &&
|
||||
body.detail &&
|
||||
typeof body.detail === "object" &&
|
||||
"message" in body.detail &&
|
||||
typeof body.detail.message === "string" &&
|
||||
body.detail.message.trim().length > 0
|
||||
) {
|
||||
return body.detail.message.trim();
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
@@ -11,25 +11,17 @@ import { NextRequest, NextResponse } from "next/server";
|
||||
export const maxDuration = 300; // 5 minutes timeout for large uploads
|
||||
export const dynamic = "force-dynamic";
|
||||
|
||||
import {
|
||||
fetchWorkspaceDownloadWithRetry,
|
||||
getWorkspaceDownloadErrorMessage,
|
||||
isWorkspaceDownloadRequest,
|
||||
} from "./route.helpers";
|
||||
|
||||
function buildBackendUrl(path: string[], queryString: string): string {
|
||||
const backendPath = path.join("/");
|
||||
return `${environment.getAGPTServerBaseUrl()}/${backendPath}${queryString}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if this is a workspace file download request that needs binary response handling.
|
||||
*/
|
||||
function isWorkspaceDownloadRequest(path: string[]): boolean {
|
||||
// Match pattern: api/workspace/files/{id}/download (5 segments)
|
||||
return (
|
||||
path.length == 5 &&
|
||||
path[0] === "api" &&
|
||||
path[1] === "workspace" &&
|
||||
path[2] === "files" &&
|
||||
path[path.length - 1] === "download"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle workspace file download requests with proper binary response streaming.
|
||||
*/
|
||||
@@ -44,17 +36,15 @@ async function handleWorkspaceDownload(
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const response = await fetch(backendUrl, {
|
||||
method: "GET",
|
||||
const response = await fetchWorkspaceDownloadWithRetry(
|
||||
backendUrl,
|
||||
headers,
|
||||
redirect: "follow", // Follow redirects to signed URLs
|
||||
});
|
||||
2,
|
||||
500,
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
return NextResponse.json(
|
||||
{ error: `Failed to download file: ${response.statusText}` },
|
||||
{ status: response.status },
|
||||
);
|
||||
return await createWorkspaceDownloadErrorResponse(response);
|
||||
}
|
||||
|
||||
// Fully buffer the response before forwarding. Passing response.body as a
|
||||
@@ -81,6 +71,34 @@ async function handleWorkspaceDownload(
|
||||
});
|
||||
}
|
||||
|
||||
async function createWorkspaceDownloadErrorResponse(
|
||||
response: Response,
|
||||
): Promise<NextResponse> {
|
||||
const contentType = response.headers.get("Content-Type")?.toLowerCase() ?? "";
|
||||
|
||||
try {
|
||||
if (contentType.includes("application/json")) {
|
||||
const body = await response.json();
|
||||
return NextResponse.json(body, { status: response.status });
|
||||
}
|
||||
|
||||
const text = await response.text();
|
||||
const detail =
|
||||
getWorkspaceDownloadErrorMessage(text) ||
|
||||
response.statusText ||
|
||||
"Failed to download file";
|
||||
|
||||
return NextResponse.json({ detail }, { status: response.status });
|
||||
} catch {
|
||||
return NextResponse.json(
|
||||
{
|
||||
detail: response.statusText || "Failed to download file",
|
||||
},
|
||||
{ status: response.status },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleJsonRequest(
|
||||
req: NextRequest,
|
||||
method: string,
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
fireEvent,
|
||||
render,
|
||||
screen,
|
||||
waitFor,
|
||||
} from "@/tests/integrations/test-utils";
|
||||
import {
|
||||
getPostV2UpdateUserProfileMockHandler200,
|
||||
getPostV2UpdateUserProfileMockHandler422,
|
||||
getPostV2UpdateUserProfileResponseMock422,
|
||||
} from "@/app/api/__generated__/endpoints/store/store.msw";
|
||||
import { server } from "@/mocks/mock-server";
|
||||
import type { ProfileDetails } from "@/app/api/__generated__/models/profileDetails";
|
||||
import { ProfileInfoForm } from "../ProfileInfoForm";
|
||||
|
||||
function makeProfile(overrides: Partial<ProfileDetails> = {}): ProfileDetails {
|
||||
return {
|
||||
name: "Initial Name",
|
||||
username: "initial-user",
|
||||
description: "Initial description",
|
||||
links: [],
|
||||
avatar_url: "",
|
||||
...overrides,
|
||||
} as ProfileDetails;
|
||||
}
|
||||
|
||||
describe("ProfileInfoForm", () => {
|
||||
it("renders the existing profile values into editable fields", () => {
|
||||
render(<ProfileInfoForm profile={makeProfile({ name: "Hello World" })} />);
|
||||
const nameInput = screen.getByTestId(
|
||||
"profile-info-form-display-name",
|
||||
) as HTMLInputElement;
|
||||
expect(nameInput.defaultValue).toBe("Hello World");
|
||||
});
|
||||
|
||||
it("submits the new display name to POST /api/store/profile and reflects the response", async () => {
|
||||
let receivedBody: Record<string, unknown> | null = null;
|
||||
|
||||
server.use(
|
||||
getPostV2UpdateUserProfileMockHandler200(async ({ request }) => {
|
||||
receivedBody = (await request.json()) as Record<string, unknown>;
|
||||
return makeProfile({ name: receivedBody?.name as string });
|
||||
}),
|
||||
);
|
||||
|
||||
render(<ProfileInfoForm profile={makeProfile({ name: "Old Name" })} />);
|
||||
|
||||
const nameInput = screen.getByTestId("profile-info-form-display-name");
|
||||
fireEvent.change(nameInput, { target: { value: "Brand New Name" } });
|
||||
|
||||
fireEvent.click(screen.getByRole("button", { name: "Save changes" }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
receivedBody,
|
||||
"POST /api/store/profile must fire when the user clicks Save",
|
||||
).not.toBeNull();
|
||||
});
|
||||
|
||||
expect(receivedBody!.name).toBe("Brand New Name");
|
||||
});
|
||||
|
||||
it("does not silently swallow the request when the API returns 422", async () => {
|
||||
let calls = 0;
|
||||
server.use(
|
||||
getPostV2UpdateUserProfileMockHandler422(() => {
|
||||
calls += 1;
|
||||
return getPostV2UpdateUserProfileResponseMock422({
|
||||
detail: [
|
||||
{
|
||||
loc: ["body", "name"],
|
||||
msg: "validation error",
|
||||
type: "value_error",
|
||||
},
|
||||
],
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
render(<ProfileInfoForm profile={makeProfile()} />);
|
||||
|
||||
const nameInput = screen.getByTestId("profile-info-form-display-name");
|
||||
fireEvent.change(nameInput, { target: { value: "Anything" } });
|
||||
fireEvent.click(screen.getByRole("button", { name: "Save changes" }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
calls,
|
||||
"save click must hit the backend even when validation fails",
|
||||
).toBeGreaterThan(0);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,3 +1,5 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import type React from "react";
|
||||
import { describe, expect, it } from "vitest";
|
||||
import { csvRenderer } from "./CSVRenderer";
|
||||
|
||||
@@ -16,6 +18,16 @@ describe("csvRenderer.canRender", () => {
|
||||
it("matches .csv filename case-insensitively", () => {
|
||||
expect(csvRenderer.canRender("a,b", { filename: "data.CSV" })).toBe(true);
|
||||
});
|
||||
it("matches TSV mime type", () => {
|
||||
expect(
|
||||
csvRenderer.canRender("a\tb\n1\t2", {
|
||||
mimeType: "text/tab-separated-values",
|
||||
}),
|
||||
).toBe(true);
|
||||
});
|
||||
it("matches .tsv filename case-insensitively", () => {
|
||||
expect(csvRenderer.canRender("a\tb", { filename: "data.TSV" })).toBe(true);
|
||||
});
|
||||
it("rejects non-string values", () => {
|
||||
expect(csvRenderer.canRender(42, { mimeType: "text/csv" })).toBe(false);
|
||||
});
|
||||
@@ -64,4 +76,16 @@ describe("csvRenderer.render (parse via render output smoke)", () => {
|
||||
const csv = 'name\n"She said ""hi"""';
|
||||
expect(() => csvRenderer.render(csv)).not.toThrow();
|
||||
});
|
||||
it("renders TSV columns using tabs as the delimiter", () => {
|
||||
render(
|
||||
csvRenderer.render("name\tage\nAlice\t30", {
|
||||
filename: "data.tsv",
|
||||
}) as React.ReactElement,
|
||||
);
|
||||
|
||||
expect(screen.getByText("name")).toBeDefined();
|
||||
expect(screen.getByText("age")).toBeDefined();
|
||||
expect(screen.getByText("Alice")).toBeDefined();
|
||||
expect(screen.getByText("30")).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -6,7 +6,35 @@ import {
|
||||
CopyContent,
|
||||
} from "../types";
|
||||
|
||||
function parseCSV(text: string): { headers: string[]; rows: string[][] } {
|
||||
function normalizeMime(mime?: string): string | undefined {
|
||||
return mime?.toLowerCase().split(";")[0]?.trim();
|
||||
}
|
||||
|
||||
function getDelimiter(metadata?: OutputMetadata): "," | "\t" {
|
||||
if (
|
||||
normalizeMime(metadata?.mimeType) === "text/tab-separated-values" ||
|
||||
metadata?.filename?.toLowerCase().endsWith(".tsv")
|
||||
) {
|
||||
return "\t";
|
||||
}
|
||||
|
||||
return ",";
|
||||
}
|
||||
|
||||
function getDelimitedMimeType(metadata?: OutputMetadata): string {
|
||||
return getDelimiter(metadata) === "\t"
|
||||
? "text/tab-separated-values"
|
||||
: "text/csv";
|
||||
}
|
||||
|
||||
function getDelimitedFallbackFilename(metadata?: OutputMetadata): string {
|
||||
return getDelimiter(metadata) === "\t" ? "data.tsv" : "data.csv";
|
||||
}
|
||||
|
||||
function parseDelimitedText(
|
||||
text: string,
|
||||
delimiter: "," | "\t",
|
||||
): { headers: string[]; rows: string[][] } {
|
||||
const normalized = text
|
||||
.replace(/\r\n?/g, "\n")
|
||||
.replace(/^\ufeff/, "")
|
||||
@@ -32,7 +60,7 @@ function parseCSV(text: string): { headers: string[]; rows: string[][] } {
|
||||
}
|
||||
} else if (ch === '"') {
|
||||
inQuotes = true;
|
||||
} else if (ch === ",") {
|
||||
} else if (ch === delimiter) {
|
||||
row.push(current);
|
||||
current = "";
|
||||
} else if (ch === "\n") {
|
||||
@@ -51,8 +79,17 @@ function parseCSV(text: string): { headers: string[]; rows: string[][] } {
|
||||
return { headers, rows: rows.slice(1) };
|
||||
}
|
||||
|
||||
function CSVTable({ value }: { value: string }) {
|
||||
const { headers, rows } = useMemo(() => parseCSV(value), [value]);
|
||||
function CSVTable({
|
||||
value,
|
||||
delimiter,
|
||||
}: {
|
||||
value: string;
|
||||
delimiter: "," | "\t";
|
||||
}) {
|
||||
const { headers, rows } = useMemo(
|
||||
() => parseDelimitedText(value, delimiter),
|
||||
[delimiter, value],
|
||||
);
|
||||
const [sortCol, setSortCol] = useState<number | null>(null);
|
||||
const [sortAsc, setSortAsc] = useState(true);
|
||||
|
||||
@@ -134,16 +171,17 @@ function CSVTable({ value }: { value: string }) {
|
||||
|
||||
function canRenderCSV(value: unknown, metadata?: OutputMetadata): boolean {
|
||||
if (typeof value !== "string") return false;
|
||||
if (metadata?.mimeType === "text/csv") return true;
|
||||
const mime = normalizeMime(metadata?.mimeType);
|
||||
if (mime === "text/csv" || mime === "text/tab-separated-values") {
|
||||
return true;
|
||||
}
|
||||
if (metadata?.filename?.toLowerCase().endsWith(".csv")) return true;
|
||||
if (metadata?.filename?.toLowerCase().endsWith(".tsv")) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
function renderCSV(
|
||||
value: unknown,
|
||||
_metadata?: OutputMetadata,
|
||||
): React.ReactNode {
|
||||
return <CSVTable value={String(value)} />;
|
||||
function renderCSV(value: unknown, metadata?: OutputMetadata): React.ReactNode {
|
||||
return <CSVTable value={String(value)} delimiter={getDelimiter(metadata)} />;
|
||||
}
|
||||
|
||||
function getCopyContentCSV(
|
||||
@@ -159,10 +197,11 @@ function getDownloadContentCSV(
|
||||
metadata?: OutputMetadata,
|
||||
): DownloadContent | null {
|
||||
const text = String(value);
|
||||
const mimeType = getDelimitedMimeType(metadata);
|
||||
return {
|
||||
data: new Blob([text], { type: "text/csv" }),
|
||||
filename: metadata?.filename || "data.csv",
|
||||
mimeType: "text/csv",
|
||||
data: new Blob([text], { type: mimeType }),
|
||||
filename: metadata?.filename || getDelimitedFallbackFilename(metadata),
|
||||
mimeType,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
import { render, screen } from "@/tests/integrations/test-utils";
|
||||
import { AgentExecutionStatus } from "@/app/api/__generated__/models/agentExecutionStatus";
|
||||
import { AgentActivityDropdown } from "../AgentActivityDropdown";
|
||||
import { AgentExecutionWithInfo } from "../helpers";
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
const mockUseAgentActivityDropdown = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock("../useAgentActivityDropdown", () => ({
|
||||
useAgentActivityDropdown: mockUseAgentActivityDropdown,
|
||||
}));
|
||||
|
||||
function makeExecution(
|
||||
overrides: Partial<AgentExecutionWithInfo> = {},
|
||||
): AgentExecutionWithInfo {
|
||||
return {
|
||||
id: "exec-1",
|
||||
graph_id: "graph-1",
|
||||
status: AgentExecutionStatus.RUNNING,
|
||||
started_at: new Date(),
|
||||
ended_at: null,
|
||||
user_id: "user-1",
|
||||
graph_version: 1,
|
||||
inputs: {},
|
||||
credential_inputs: {},
|
||||
nodes_input_masks: {},
|
||||
preset_id: null,
|
||||
stats: null,
|
||||
agent_name: "Test Agent",
|
||||
agent_description: "A running agent",
|
||||
library_agent_id: "library-1",
|
||||
...overrides,
|
||||
};
|
||||
}
|
||||
|
||||
describe("AgentActivityDropdown", () => {
|
||||
beforeEach(() => {
|
||||
mockUseAgentActivityDropdown.mockReturnValue({
|
||||
activeExecutions: [makeExecution(), makeExecution({ id: "exec-2" })],
|
||||
recentCompletions: [],
|
||||
recentFailures: [],
|
||||
totalCount: 2,
|
||||
isReady: true,
|
||||
error: null,
|
||||
isOpen: false,
|
||||
setIsOpen: vi.fn(),
|
||||
});
|
||||
});
|
||||
|
||||
test("shows the active execution badge count", () => {
|
||||
render(<AgentActivityDropdown />);
|
||||
|
||||
expect(screen.getByTestId("agent-activity-badge").textContent).toContain(
|
||||
"2",
|
||||
);
|
||||
expect(screen.getByTestId("agent-activity-button")).toBeDefined();
|
||||
});
|
||||
|
||||
test("renders the dropdown content when open", async () => {
|
||||
mockUseAgentActivityDropdown.mockReturnValue({
|
||||
activeExecutions: [makeExecution()],
|
||||
recentCompletions: [],
|
||||
recentFailures: [],
|
||||
totalCount: 1,
|
||||
isReady: true,
|
||||
error: null,
|
||||
isOpen: true,
|
||||
setIsOpen: vi.fn(),
|
||||
});
|
||||
|
||||
render(<AgentActivityDropdown />);
|
||||
|
||||
expect(screen.getByTestId("agent-activity-dropdown")).toBeDefined();
|
||||
expect(await screen.findByText("Test Agent")).toBeDefined();
|
||||
});
|
||||
});
|
||||
97
autogpt_platform/frontend/src/lib/utils.test.ts
Normal file
97
autogpt_platform/frontend/src/lib/utils.test.ts
Normal file
@@ -0,0 +1,97 @@
|
||||
import { describe, expect, test } from "vitest";
|
||||
import { setNestedProperty } from "./utils";
|
||||
|
||||
const testCases = [
|
||||
{
|
||||
name: "simple property assignment",
|
||||
path: "name",
|
||||
value: "John",
|
||||
expected: { name: "John" },
|
||||
},
|
||||
{
|
||||
name: "nested property with dot notation",
|
||||
path: "user.settings.theme",
|
||||
value: "dark",
|
||||
expected: { user: { settings: { theme: "dark" } } },
|
||||
},
|
||||
{
|
||||
name: "nested property with slash notation",
|
||||
path: "user/settings/language",
|
||||
value: "en",
|
||||
expected: { user: { settings: { language: "en" } } },
|
||||
},
|
||||
{
|
||||
name: "mixed dot and slash notation",
|
||||
path: "user.settings/preferences.color",
|
||||
value: "blue",
|
||||
expected: { user: { settings: { preferences: { color: "blue" } } } },
|
||||
},
|
||||
{
|
||||
name: "overwrite primitive with object",
|
||||
path: "user.details",
|
||||
value: { age: 30 },
|
||||
expected: { user: { details: { age: 30 } } },
|
||||
},
|
||||
];
|
||||
|
||||
describe("setNestedProperty", () => {
|
||||
for (const { name, path, value, expected } of testCases) {
|
||||
test(name, () => {
|
||||
const obj = {};
|
||||
setNestedProperty(obj, path, value);
|
||||
expect(obj).toEqual(expected);
|
||||
});
|
||||
}
|
||||
|
||||
test("throws for null object", () => {
|
||||
expect(() => {
|
||||
setNestedProperty(null, "test", "value");
|
||||
}).toThrow("Target must be a non-null object");
|
||||
});
|
||||
|
||||
test("throws for undefined object", () => {
|
||||
expect(() => {
|
||||
setNestedProperty(undefined, "test", "value");
|
||||
}).toThrow("Target must be a non-null object");
|
||||
});
|
||||
|
||||
test("throws for non-object target", () => {
|
||||
expect(() => {
|
||||
setNestedProperty("string", "test", "value");
|
||||
}).toThrow("Target must be a non-null object");
|
||||
});
|
||||
|
||||
test("throws for empty path", () => {
|
||||
expect(() => {
|
||||
setNestedProperty({}, "", "value");
|
||||
}).toThrow("Path must be a non-empty string");
|
||||
});
|
||||
|
||||
test("throws for __proto__ access", () => {
|
||||
expect(() => {
|
||||
setNestedProperty({}, "__proto__.malicious", "attack");
|
||||
}).toThrow("Invalid property name: __proto__");
|
||||
});
|
||||
|
||||
test("throws for constructor access", () => {
|
||||
expect(() => {
|
||||
setNestedProperty({}, "constructor.prototype.malicious", "attack");
|
||||
}).toThrow("Invalid property name: constructor");
|
||||
});
|
||||
|
||||
test("throws for prototype access", () => {
|
||||
expect(() => {
|
||||
setNestedProperty({}, "obj.prototype.malicious", "attack");
|
||||
}).toThrow("Invalid property name: prototype");
|
||||
});
|
||||
|
||||
test("prevents prototype pollution", () => {
|
||||
const obj = {};
|
||||
|
||||
expect(() => {
|
||||
setNestedProperty(obj, "__proto__.polluted", true);
|
||||
}).toThrow("Invalid property name: __proto__");
|
||||
|
||||
expect(({} as { polluted?: boolean }).polluted).toBeUndefined();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,100 @@
|
||||
import { randomUUID } from "crypto";
|
||||
import { expect, test } from "./coverage-fixture";
|
||||
import { E2E_AUTH_STATES } from "./credentials/accounts";
|
||||
|
||||
test.use({ storageState: E2E_AUTH_STATES.parallelB });
|
||||
|
||||
test("api keys happy path: user can create, copy, and revoke an API key", async ({
|
||||
page,
|
||||
context,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
await context.grantPermissions(["clipboard-read", "clipboard-write"]);
|
||||
|
||||
const keyName = `E2E CLI Key ${randomUUID().slice(0, 8)}`;
|
||||
|
||||
await page.goto("/profile/api-keys");
|
||||
await expect(page).toHaveURL(/\/profile\/api-keys/);
|
||||
await expect(
|
||||
page.getByText(
|
||||
"Manage your AutoGPT Platform API keys for programmatic access",
|
||||
),
|
||||
).toBeVisible();
|
||||
|
||||
await page.getByRole("button", { name: "Create Key" }).click();
|
||||
await page.getByLabel("Name").fill(keyName);
|
||||
const executeGraphCheckbox = page.getByRole("checkbox", {
|
||||
name: /EXECUTE_GRAPH/i,
|
||||
});
|
||||
const executeGraphChecked =
|
||||
(await executeGraphCheckbox.getAttribute("aria-checked")) === "true";
|
||||
if (!executeGraphChecked) {
|
||||
await executeGraphCheckbox.click();
|
||||
}
|
||||
await expect(executeGraphCheckbox).toHaveAttribute("aria-checked", "true");
|
||||
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
const secretDialog = page.getByRole("dialog", {
|
||||
name: "AutoGPT Platform API Key Created",
|
||||
});
|
||||
await expect
|
||||
.poll(
|
||||
async () => {
|
||||
if (await secretDialog.isVisible().catch(() => false)) {
|
||||
return "created";
|
||||
}
|
||||
|
||||
const creationFailed = await page
|
||||
.getByText("Failed to create AutoGPT Platform API key")
|
||||
.isVisible()
|
||||
.catch(() => false);
|
||||
if (creationFailed) {
|
||||
return "failed";
|
||||
}
|
||||
|
||||
return "pending";
|
||||
},
|
||||
{
|
||||
timeout: 30000,
|
||||
message:
|
||||
"API key creation should either open the created-key dialog or surface an explicit failure toast",
|
||||
},
|
||||
)
|
||||
.toBe("created");
|
||||
await expect(secretDialog).toBeVisible();
|
||||
|
||||
const createdSecret = (
|
||||
(await secretDialog.locator("code").textContent()) ?? ""
|
||||
).trim();
|
||||
expect(createdSecret.length).toBeGreaterThan(0);
|
||||
|
||||
await secretDialog.getByRole("button").first().click();
|
||||
await expect(page.getByText("Copied", { exact: true })).toBeVisible({
|
||||
timeout: 15000,
|
||||
});
|
||||
await expect
|
||||
.poll(() => page.evaluate(() => navigator.clipboard.readText()), {
|
||||
timeout: 10000,
|
||||
})
|
||||
.toBe(createdSecret);
|
||||
|
||||
await secretDialog.getByRole("button", { name: "Close" }).first().click();
|
||||
|
||||
const createdKeyRow = page
|
||||
.getByTestId("api-key-row")
|
||||
.filter({ hasText: keyName })
|
||||
.first();
|
||||
await expect(createdKeyRow).toBeVisible({ timeout: 15000 });
|
||||
|
||||
await createdKeyRow.getByTestId("api-key-actions").click();
|
||||
await page.getByRole("menuitem", { name: "Revoke" }).click();
|
||||
|
||||
await expect(
|
||||
page.getByText("AutoGPT Platform API key revoked successfully"),
|
||||
).toBeVisible({ timeout: 15000 });
|
||||
await expect(
|
||||
page.getByTestId("api-key-row").filter({ hasText: keyName }),
|
||||
).toHaveCount(0);
|
||||
});
|
||||
158
autogpt_platform/frontend/src/playwright/auth-happy-path.spec.ts
Normal file
158
autogpt_platform/frontend/src/playwright/auth-happy-path.spec.ts
Normal file
@@ -0,0 +1,158 @@
|
||||
import { expect, test } from "./coverage-fixture";
|
||||
import { getSeededTestUser } from "./credentials/accounts";
|
||||
import { BuildPage } from "./pages/build.page";
|
||||
import { LoginPage } from "./pages/login.page";
|
||||
import {
|
||||
completeOnboardingWizard,
|
||||
skipOnboardingIfPresent,
|
||||
} from "./utils/onboarding";
|
||||
import { signupTestUser } from "./utils/signup";
|
||||
|
||||
test("auth happy path: user can sign up with a fresh account", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(60000);
|
||||
|
||||
await signupTestUser(page, undefined, undefined, false);
|
||||
await expect(page).toHaveURL(/\/onboarding/);
|
||||
await expect(page.getByText("Welcome to AutoGPT")).toBeVisible();
|
||||
});
|
||||
|
||||
test("auth happy path: user can sign up, enter the app, and log out", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(90000);
|
||||
|
||||
await signupTestUser(page, undefined, undefined, false);
|
||||
await expect(page).toHaveURL(/\/onboarding/);
|
||||
await expect(page.getByText("Welcome to AutoGPT")).toBeVisible();
|
||||
|
||||
await skipOnboardingIfPresent(page, "/marketplace");
|
||||
await expect(page).toHaveURL(/\/marketplace/);
|
||||
await expect(page.getByTestId("profile-popout-menu-trigger")).toBeVisible();
|
||||
|
||||
await page.getByTestId("profile-popout-menu-trigger").click();
|
||||
await page.getByRole("button", { name: "Log out" }).click();
|
||||
|
||||
await expect(page).toHaveURL(/\/login/);
|
||||
|
||||
await page.goto("/library");
|
||||
await expect(page).toHaveURL(/\/login\?next=%2Flibrary/);
|
||||
});
|
||||
|
||||
test("auth happy path: seeded user can log in", async ({ page }) => {
|
||||
test.setTimeout(60000);
|
||||
|
||||
const testUser = getSeededTestUser("smokeAuth");
|
||||
const loginPage = new LoginPage(page);
|
||||
|
||||
await page.goto("/login");
|
||||
await loginPage.login(testUser.email, testUser.password);
|
||||
|
||||
await expect(page).toHaveURL(/\/marketplace/);
|
||||
await expect(page.getByTestId("profile-popout-menu-trigger")).toBeVisible();
|
||||
});
|
||||
|
||||
test("auth happy path: seeded user can log out and protected routes redirect to login", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(60000);
|
||||
|
||||
const testUser = getSeededTestUser("primary");
|
||||
const loginPage = new LoginPage(page);
|
||||
|
||||
await page.goto("/login");
|
||||
await loginPage.login(testUser.email, testUser.password);
|
||||
|
||||
await expect(page).toHaveURL(/\/marketplace/);
|
||||
await page.getByTestId("profile-popout-menu-trigger").click();
|
||||
await page.getByRole("button", { name: "Log out" }).click();
|
||||
|
||||
await expect(page).toHaveURL(/\/login/, { timeout: 15000 });
|
||||
|
||||
await page.goto("/profile");
|
||||
await expect(page).toHaveURL(/\/login\?next=%2Fprofile/);
|
||||
});
|
||||
|
||||
test("auth happy path: user can complete onboarding and land in the app", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(60000);
|
||||
|
||||
await signupTestUser(page, undefined, undefined, false);
|
||||
await expect(page).toHaveURL(/\/onboarding/);
|
||||
|
||||
await completeOnboardingWizard(page, {
|
||||
name: "Smoke User",
|
||||
role: "Engineering",
|
||||
painPoints: ["Research", "Reports & data"],
|
||||
});
|
||||
|
||||
await expect(page).toHaveURL(/\/copilot/);
|
||||
await expect(page.getByTestId("profile-popout-menu-trigger")).toBeVisible();
|
||||
});
|
||||
|
||||
test("auth happy path: multi-tab logout clears shared builder sessions", async ({
|
||||
context,
|
||||
}) => {
|
||||
// Two pages + builder load + logout sequence justifies a higher timeout
|
||||
test.setTimeout(90000);
|
||||
|
||||
const consoleErrors: string[] = [];
|
||||
|
||||
const page1 = await context.newPage();
|
||||
const page2 = await context.newPage();
|
||||
const buildPage = new BuildPage(page1);
|
||||
|
||||
const recordWebSocketErrors =
|
||||
(label: string) => (msg: { type: () => string; text: () => string }) => {
|
||||
if (msg.type() === "error" && msg.text().includes("WebSocket")) {
|
||||
consoleErrors.push(`${label}: ${msg.text()}`);
|
||||
}
|
||||
};
|
||||
|
||||
page1.on("console", recordWebSocketErrors("page1"));
|
||||
page2.on("console", recordWebSocketErrors("page2"));
|
||||
|
||||
await signupTestUser(page1, undefined, undefined, false);
|
||||
await expect(page1).toHaveURL(/\/onboarding/);
|
||||
await skipOnboardingIfPresent(page1, "/build");
|
||||
|
||||
await page1.goto("/build");
|
||||
await expect(page1).toHaveURL(/\/build/);
|
||||
await buildPage.closeTutorial();
|
||||
await expect(page1.getByTestId("profile-popout-menu-trigger")).toBeVisible();
|
||||
|
||||
await page2.goto("/build");
|
||||
await expect(page2).toHaveURL(/\/build/);
|
||||
await expect(page2.getByTestId("profile-popout-menu-trigger")).toBeVisible();
|
||||
|
||||
await page1.getByTestId("profile-popout-menu-trigger").click();
|
||||
await page1.getByRole("button", { name: "Log out" }).click();
|
||||
await expect(page1).toHaveURL(/\/login/);
|
||||
|
||||
await page2.reload();
|
||||
await expect(page2).toHaveURL(/\/login\?next=%2Fbuild/);
|
||||
await expect(page2.getByTestId("profile-popout-menu-trigger")).toBeHidden();
|
||||
|
||||
expect(consoleErrors).toHaveLength(0);
|
||||
|
||||
// Prove the auth token is actually gone, not just the UI hidden. Supabase
|
||||
// overwrites the cookie on signout with an empty value + past expiry
|
||||
// rather than deleting it. An assertion that is silently skipped when the
|
||||
// cookie is missing under the expected name would hide a real regression,
|
||||
// so we assert on every non-empty sb-*auth-token* cookie explicitly.
|
||||
const cookiesAfterLogout = await context.cookies();
|
||||
const authCookies = cookiesAfterLogout.filter(
|
||||
(c) => c.name.startsWith("sb-") && c.name.includes("auth-token"),
|
||||
);
|
||||
for (const cookie of authCookies) {
|
||||
expect(
|
||||
cookie.value,
|
||||
`supabase auth cookie ${cookie.name} must be empty after logout`,
|
||||
).toBe("");
|
||||
}
|
||||
|
||||
await page1.close();
|
||||
await page2.close();
|
||||
});
|
||||
@@ -0,0 +1,83 @@
|
||||
import { expect, test } from "./coverage-fixture";
|
||||
import { E2E_AUTH_STATES } from "./credentials/accounts";
|
||||
import { BuildPage } from "./pages/build.page";
|
||||
|
||||
test.use({ storageState: E2E_AUTH_STATES.builder });
|
||||
|
||||
test("builder happy path: user can walk through the builder tutorial and cancel midway, persisting canceled state", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(180000);
|
||||
|
||||
const buildPage = new BuildPage(page);
|
||||
await buildPage.startTutorial();
|
||||
await buildPage.walkWelcomeToBlockMenu();
|
||||
await buildPage.walkSearchAndAddCalculator();
|
||||
await buildPage.cancelTutorial();
|
||||
|
||||
expect(await buildPage.getTutorialStateFromStorage()).toBe("canceled");
|
||||
expect(await buildPage.getNodeCount()).toBeGreaterThanOrEqual(1);
|
||||
});
|
||||
|
||||
test("builder happy path: user can skip the builder tutorial from the welcome step", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(60000);
|
||||
|
||||
const buildPage = new BuildPage(page);
|
||||
await buildPage.startTutorial();
|
||||
await buildPage.skipTutorialFromWelcome();
|
||||
});
|
||||
|
||||
test("builder happy path: user can create a simple agent in builder with core blocks", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const buildPage = new BuildPage(page);
|
||||
await buildPage.open();
|
||||
await buildPage.addSimpleAgentBlocks();
|
||||
|
||||
await expect(buildPage.getNodeLocator()).toHaveCount(2);
|
||||
await expect(
|
||||
buildPage
|
||||
.getNodeLocator(0)
|
||||
.locator('input[placeholder="Enter string value..."]'),
|
||||
).toHaveValue("smoke-value");
|
||||
await expect(buildPage.getNodeTextInput("Add to Dictionary", 0)).toHaveValue(
|
||||
"smoke-key",
|
||||
);
|
||||
await expect(buildPage.getNodeTextInput("Add to Dictionary", 1)).toHaveValue(
|
||||
"smoke-value",
|
||||
);
|
||||
});
|
||||
|
||||
test("builder happy path: user can save the created agent", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const buildPage = new BuildPage(page);
|
||||
await buildPage.createAndSaveSimpleAgent("Smoke Save Agent");
|
||||
|
||||
await expect(page).toHaveURL(/flowID=/);
|
||||
expect(await buildPage.isRunButtonEnabled()).toBeTruthy();
|
||||
});
|
||||
|
||||
test("builder happy path: user can run the saved agent from builder and see execution state", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const buildPage = new BuildPage(page);
|
||||
await buildPage.createAndSaveSimpleAgent("Smoke Run Agent");
|
||||
|
||||
await buildPage.startRun();
|
||||
await expect(
|
||||
page.locator('[data-id="stop-graph-button"], [data-id="run-graph-button"]'),
|
||||
).toBeVisible({ timeout: 15000 });
|
||||
|
||||
await expect
|
||||
.poll(() => buildPage.getExecutionState(), { timeout: 15000 })
|
||||
.not.toBe("unknown");
|
||||
});
|
||||
@@ -0,0 +1,44 @@
|
||||
import { expect, test } from "./coverage-fixture";
|
||||
import { E2E_AUTH_STATES } from "./credentials/accounts";
|
||||
import { CopilotPage } from "./pages/copilot.page";
|
||||
|
||||
test.use({ storageState: E2E_AUTH_STATES.marketplace });
|
||||
|
||||
test("copilot happy path: user can create a deterministic AutoPilot session and keep it after reload", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const copilotPage = new CopilotPage(page);
|
||||
await copilotPage.open();
|
||||
|
||||
const sessionId = await copilotPage.createSessionViaApi();
|
||||
|
||||
await copilotPage.open(sessionId);
|
||||
await copilotPage.waitForChatInput();
|
||||
|
||||
await page.reload();
|
||||
await page.waitForLoadState("domcontentloaded");
|
||||
await copilotPage.dismissNotificationPrompt();
|
||||
|
||||
await expect
|
||||
.poll(() => new URL(page.url()).searchParams.get("sessionId"), {
|
||||
timeout: 15000,
|
||||
})
|
||||
.toBe(sessionId);
|
||||
await copilotPage.waitForChatInput();
|
||||
|
||||
// Sending a message must render the user's prompt in the conversation
|
||||
// immediately. This catches a regression where the chat input accepts
|
||||
// text but Enter is a no-op, without depending on knowing the exact
|
||||
// backend endpoint name (which has shifted historically).
|
||||
const userPrompt = `ping from e2e ${Date.now().toString().slice(-6)}`;
|
||||
const chatInput = copilotPage.getChatInput();
|
||||
await chatInput.fill(userPrompt);
|
||||
await chatInput.press("Enter");
|
||||
|
||||
await expect(
|
||||
page.getByText(userPrompt, { exact: false }).first(),
|
||||
"user's typed prompt must appear in the chat after pressing Enter",
|
||||
).toBeVisible({ timeout: 15000 });
|
||||
});
|
||||
@@ -0,0 +1,85 @@
|
||||
import path from "path";
|
||||
|
||||
export const SEEDED_TEST_PASSWORD =
|
||||
process.env.SEEDED_TEST_PASSWORD || "testpassword123";
|
||||
export const SEEDED_USER_POOL_VERSION = "2.0.0";
|
||||
|
||||
export const SEEDED_TEST_ACCOUNTS = {
|
||||
primary: {
|
||||
key: "primary",
|
||||
email: "test123@example.com",
|
||||
password: SEEDED_TEST_PASSWORD,
|
||||
},
|
||||
smokeAuth: {
|
||||
key: "smokeAuth",
|
||||
email: "e2e.qa.auth@example.com",
|
||||
password: SEEDED_TEST_PASSWORD,
|
||||
},
|
||||
smokeBuilder: {
|
||||
key: "smokeBuilder",
|
||||
email: "e2e.qa.builder@example.com",
|
||||
password: SEEDED_TEST_PASSWORD,
|
||||
},
|
||||
smokeLibrary: {
|
||||
key: "smokeLibrary",
|
||||
email: "e2e.qa.library@example.com",
|
||||
password: SEEDED_TEST_PASSWORD,
|
||||
},
|
||||
smokeMarketplace: {
|
||||
key: "smokeMarketplace",
|
||||
email: "e2e.qa.marketplace@example.com",
|
||||
password: SEEDED_TEST_PASSWORD,
|
||||
},
|
||||
smokeSettings: {
|
||||
key: "smokeSettings",
|
||||
email: "e2e.qa.settings@example.com",
|
||||
password: SEEDED_TEST_PASSWORD,
|
||||
},
|
||||
parallelA: {
|
||||
key: "parallelA",
|
||||
email: "e2e.qa.parallel.a@example.com",
|
||||
password: SEEDED_TEST_PASSWORD,
|
||||
},
|
||||
parallelB: {
|
||||
key: "parallelB",
|
||||
email: "e2e.qa.parallel.b@example.com",
|
||||
password: SEEDED_TEST_PASSWORD,
|
||||
},
|
||||
} as const;
|
||||
|
||||
export type SeededTestAccountKey = keyof typeof SEEDED_TEST_ACCOUNTS;
|
||||
export type SeededTestAccount =
|
||||
(typeof SEEDED_TEST_ACCOUNTS)[SeededTestAccountKey];
|
||||
|
||||
export const SEEDED_TEST_USERS = Object.values(SEEDED_TEST_ACCOUNTS);
|
||||
export const SEEDED_AUTH_STATE_ACCOUNT_KEYS = [
|
||||
"smokeBuilder",
|
||||
"smokeLibrary",
|
||||
"smokeMarketplace",
|
||||
"smokeSettings",
|
||||
"parallelA",
|
||||
"parallelB",
|
||||
] as const;
|
||||
|
||||
export const AUTH_DIRECTORY = path.resolve(process.cwd(), ".auth");
|
||||
|
||||
export function getAuthStatePath(accountKey: SeededTestAccountKey) {
|
||||
return path.join(AUTH_DIRECTORY, "states", `${accountKey}.json`);
|
||||
}
|
||||
|
||||
export const E2E_AUTH_STATES = {
|
||||
builder: getAuthStatePath("smokeBuilder"),
|
||||
library: getAuthStatePath("smokeLibrary"),
|
||||
marketplace: getAuthStatePath("smokeMarketplace"),
|
||||
settings: getAuthStatePath("smokeSettings"),
|
||||
parallelA: getAuthStatePath("parallelA"),
|
||||
parallelB: getAuthStatePath("parallelB"),
|
||||
} as const;
|
||||
|
||||
export const SMOKE_AUTH_STATES = E2E_AUTH_STATES;
|
||||
|
||||
export function getSeededTestUser(
|
||||
accountKey: SeededTestAccountKey = "primary",
|
||||
): SeededTestAccount {
|
||||
return SEEDED_TEST_ACCOUNTS[accountKey];
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
import { getSeededTestUser } from "./accounts";
|
||||
|
||||
// E2E Test Credentials and Constants
|
||||
export const TEST_CREDENTIALS = getSeededTestUser("primary");
|
||||
|
||||
export function getTestUserWithLibraryAgents() {
|
||||
return TEST_CREDENTIALS;
|
||||
}
|
||||
|
||||
// Dummy constant to help developers identify agents that don't need input
|
||||
export const DummyInput = "DummyInput";
|
||||
|
||||
// This will be used for testing agent submission for test123@example.com
|
||||
export const TEST_AGENT_DATA = {
|
||||
name: "E2E Calculator Agent",
|
||||
description:
|
||||
"A deterministic marketplace agent built from Calculator and Agent Output blocks for frontend E2E coverage.",
|
||||
image_urls: [
|
||||
"https://picsum.photos/seed/e2e-marketplace-1/200/300",
|
||||
"https://picsum.photos/seed/e2e-marketplace-2/200/301",
|
||||
"https://picsum.photos/seed/e2e-marketplace-3/200/302",
|
||||
],
|
||||
video_url: "https://www.youtube.com/watch?v=test123",
|
||||
sub_heading: "A deterministic calculator agent for PR E2E coverage",
|
||||
categories: ["test", "demo", "frontend"],
|
||||
changes_summary: "Initial deterministic calculator submission",
|
||||
} as const;
|
||||
@@ -0,0 +1,23 @@
|
||||
export function buildCookieConsentStorageState(
|
||||
origin: string = "http://localhost:3000",
|
||||
) {
|
||||
return {
|
||||
cookies: [],
|
||||
origins: [
|
||||
{
|
||||
origin,
|
||||
localStorage: [
|
||||
{
|
||||
name: "autogpt_cookie_consent",
|
||||
value: JSON.stringify({
|
||||
hasConsented: true,
|
||||
timestamp: Date.now(),
|
||||
analytics: true,
|
||||
monitoring: true,
|
||||
}),
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
49
autogpt_platform/frontend/src/playwright/global-setup.ts
Normal file
49
autogpt_platform/frontend/src/playwright/global-setup.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { FullConfig } from "@playwright/test";
|
||||
import {
|
||||
ensureSeededAuthStates,
|
||||
getInvalidSeededAuthStateKeys,
|
||||
} from "./utils/auth";
|
||||
|
||||
function resolveBaseURL(config: FullConfig) {
|
||||
const configuredBaseURL =
|
||||
config.projects[0]?.use?.baseURL ?? "http://localhost:3000";
|
||||
|
||||
if (typeof configuredBaseURL !== "string") {
|
||||
throw new Error(
|
||||
`Playwright baseURL must be a string during global setup. Received ${String(
|
||||
configuredBaseURL,
|
||||
)}.`,
|
||||
);
|
||||
}
|
||||
|
||||
return configuredBaseURL;
|
||||
}
|
||||
|
||||
async function globalSetup(config: FullConfig) {
|
||||
console.log("🚀 Starting global test setup...");
|
||||
|
||||
try {
|
||||
const baseURL = resolveBaseURL(config);
|
||||
const invalidKeys = await getInvalidSeededAuthStateKeys(baseURL);
|
||||
|
||||
if (invalidKeys.length === 0) {
|
||||
console.log("♻️ Reusing stored seeded auth states");
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(
|
||||
`🔐 Refreshing seeded auth states for: ${invalidKeys.join(", ")}`,
|
||||
);
|
||||
await ensureSeededAuthStates(baseURL);
|
||||
|
||||
console.log("✅ Global setup completed successfully!");
|
||||
} catch (error) {
|
||||
console.error("❌ Global setup failed:", error);
|
||||
console.error(
|
||||
"💡 Run backend/test/e2e_test_data.py to seed the deterministic Playwright accounts before retrying.",
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
export default globalSetup;
|
||||
@@ -0,0 +1,559 @@
|
||||
import path from "path";
|
||||
import type { Page } from "@playwright/test";
|
||||
import { expect, test } from "./coverage-fixture";
|
||||
import { E2E_AUTH_STATES } from "./credentials/accounts";
|
||||
import { BuildPage, createUniqueAgentName } from "./pages/build.page";
|
||||
import {
|
||||
clickRunButton,
|
||||
dismissFeedbackDialog,
|
||||
getActiveItemId,
|
||||
importAgentFromFile,
|
||||
LibraryPage,
|
||||
} from "./pages/library.page";
|
||||
|
||||
test.use({ storageState: E2E_AUTH_STATES.library });
|
||||
|
||||
const TEST_AGENT_PATH = path.resolve(__dirname, "assets", "testing_agent.json");
|
||||
const CALCULATOR_BLOCK_ID = "b1ab9b19-67a6-406d-abf5-2dba76d00c79";
|
||||
const AGENT_OUTPUT_BLOCK_ID = "363ae599-353e-4804-937e-b2ee3cef3da4";
|
||||
const STOPPED_RUN_STATUSES = new Set([
|
||||
"terminated",
|
||||
"failed",
|
||||
"incomplete",
|
||||
"completed",
|
||||
]);
|
||||
|
||||
type UploadedGraphNode = {
|
||||
id: string;
|
||||
block_id: string;
|
||||
input_default: Record<string, unknown>;
|
||||
metadata: {
|
||||
position: {
|
||||
x: number;
|
||||
y: number;
|
||||
};
|
||||
};
|
||||
input_links: unknown[];
|
||||
output_links: unknown[];
|
||||
};
|
||||
|
||||
function createLongRunningCalculatorGraph(
|
||||
agentName: string,
|
||||
calculatorCount: number = 150,
|
||||
) {
|
||||
const nodes: UploadedGraphNode[] = Array.from(
|
||||
{ length: calculatorCount },
|
||||
(_, index) => ({
|
||||
id: `calc-${index + 1}`,
|
||||
block_id: CALCULATOR_BLOCK_ID,
|
||||
input_default:
|
||||
index === 0
|
||||
? {
|
||||
operation: "Add",
|
||||
a: 1,
|
||||
b: 1,
|
||||
round_result: false,
|
||||
}
|
||||
: {
|
||||
operation: "Add",
|
||||
b: 1,
|
||||
round_result: false,
|
||||
},
|
||||
metadata: {
|
||||
position: { x: 320 * index, y: 120 },
|
||||
},
|
||||
input_links: [],
|
||||
output_links: [],
|
||||
}),
|
||||
);
|
||||
|
||||
const links = Array.from({ length: calculatorCount - 1 }, (_, index) => ({
|
||||
source_id: `calc-${index + 1}`,
|
||||
sink_id: `calc-${index + 2}`,
|
||||
source_name: "result",
|
||||
sink_name: "a",
|
||||
}));
|
||||
|
||||
nodes.push({
|
||||
id: "final-output",
|
||||
block_id: AGENT_OUTPUT_BLOCK_ID,
|
||||
input_default: {
|
||||
name: "Final result",
|
||||
description: "Long-running calculator chain output",
|
||||
},
|
||||
metadata: {
|
||||
position: { x: 320 * calculatorCount, y: 120 },
|
||||
},
|
||||
input_links: [],
|
||||
output_links: [],
|
||||
});
|
||||
links.push({
|
||||
source_id: `calc-${calculatorCount}`,
|
||||
sink_id: "final-output",
|
||||
source_name: "result",
|
||||
sink_name: "value",
|
||||
});
|
||||
|
||||
return {
|
||||
name: agentName,
|
||||
description:
|
||||
"Deterministic long-running calculator chain for runner stop coverage",
|
||||
is_active: true,
|
||||
nodes,
|
||||
links,
|
||||
};
|
||||
}
|
||||
|
||||
async function createLongRunningSavedAgent(
|
||||
page: Page,
|
||||
agentName: string,
|
||||
): Promise<{ graphId: string; graphVersion: number }> {
|
||||
const response = await page.request.post("/api/proxy/api/graphs", {
|
||||
data: {
|
||||
graph: createLongRunningCalculatorGraph(agentName),
|
||||
source: "upload",
|
||||
},
|
||||
});
|
||||
expect(response.ok(), "expected graph creation API request to succeed").toBe(
|
||||
true,
|
||||
);
|
||||
|
||||
const body = (await response.json()) as {
|
||||
id?: string;
|
||||
version?: number;
|
||||
data?: { id?: string; version?: number };
|
||||
};
|
||||
expect(
|
||||
body.data?.id ?? body.id,
|
||||
"graph creation should return a graph id",
|
||||
).toBeTruthy();
|
||||
|
||||
return {
|
||||
graphId: String(body.data?.id ?? body.id),
|
||||
graphVersion: Number(body.data?.version ?? body.version ?? 1),
|
||||
};
|
||||
}
|
||||
|
||||
async function createDeterministicCalculatorSavedAgent(
|
||||
page: Page,
|
||||
agentName: string,
|
||||
outputName: string,
|
||||
): Promise<void> {
|
||||
const response = await page.request.post("/api/proxy/api/graphs", {
|
||||
data: {
|
||||
graph: {
|
||||
name: agentName,
|
||||
description:
|
||||
"Deterministic calculator output for run-result assertions",
|
||||
is_active: true,
|
||||
nodes: [
|
||||
{
|
||||
id: "calc-1",
|
||||
block_id: CALCULATOR_BLOCK_ID,
|
||||
input_default: {
|
||||
operation: "Add",
|
||||
a: 1,
|
||||
b: 1,
|
||||
round_result: false,
|
||||
},
|
||||
metadata: {
|
||||
position: { x: 120, y: 160 },
|
||||
},
|
||||
input_links: [],
|
||||
output_links: [],
|
||||
},
|
||||
{
|
||||
id: "final-output",
|
||||
block_id: AGENT_OUTPUT_BLOCK_ID,
|
||||
input_default: {
|
||||
name: outputName,
|
||||
description: "Deterministic result output",
|
||||
},
|
||||
metadata: {
|
||||
position: { x: 520, y: 160 },
|
||||
},
|
||||
input_links: [],
|
||||
output_links: [],
|
||||
},
|
||||
],
|
||||
links: [
|
||||
{
|
||||
source_id: "calc-1",
|
||||
sink_id: "final-output",
|
||||
source_name: "result",
|
||||
sink_name: "value",
|
||||
},
|
||||
],
|
||||
},
|
||||
source: "upload",
|
||||
},
|
||||
});
|
||||
expect(
|
||||
response.ok(),
|
||||
"expected deterministic calculator graph creation API request to succeed",
|
||||
).toBe(true);
|
||||
}
|
||||
|
||||
async function getExecutionStatusFromApi(
|
||||
page: Page,
|
||||
graphId: string,
|
||||
runId: string,
|
||||
): Promise<string> {
|
||||
const response = await page.request.get(
|
||||
`/api/proxy/api/graphs/${graphId}/executions/${runId}`,
|
||||
);
|
||||
expect(response.ok(), "execution details API should succeed").toBe(true);
|
||||
|
||||
const body = (await response.json()) as { status?: string };
|
||||
return body.status?.toLowerCase() ?? "unknown";
|
||||
}
|
||||
|
||||
async function createAndSaveDeterministicOutputAgent(
|
||||
page: Page,
|
||||
prefix: string,
|
||||
): Promise<{ agentName: string; expectedOutput: string; outputName: string }> {
|
||||
const buildPage = new BuildPage(page);
|
||||
const agentName = createUniqueAgentName(prefix);
|
||||
const expectedOutput = `e2e-output-${Date.now()}`;
|
||||
const outputName = `e2e-result-${Date.now()}`;
|
||||
|
||||
await buildPage.open();
|
||||
await buildPage.addBlockByClick("Store Value");
|
||||
await buildPage.waitForNodeOnCanvas(1);
|
||||
await buildPage.fillBlockInputByPlaceholder(
|
||||
"Enter string value...",
|
||||
expectedOutput,
|
||||
0,
|
||||
);
|
||||
|
||||
await buildPage.addBlockByClick("Agent Output");
|
||||
await buildPage.waitForNodeOnCanvas(2);
|
||||
await buildPage.connectNodes(0, 1);
|
||||
await buildPage.fillLastNodeTextInput("Agent Output", outputName);
|
||||
|
||||
await buildPage.saveAgent(
|
||||
agentName,
|
||||
"Deterministic output agent for library run verification",
|
||||
);
|
||||
await buildPage.waitForSaveComplete();
|
||||
await buildPage.waitForSaveButton();
|
||||
|
||||
return { agentName, expectedOutput, outputName };
|
||||
}
|
||||
|
||||
test("library happy path: user can import an agent file into Library", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const { importedAgent } = await importAgentFromFile(
|
||||
page,
|
||||
TEST_AGENT_PATH,
|
||||
createUniqueAgentName("E2E Import Agent"),
|
||||
);
|
||||
|
||||
expect(importedAgent.name).toContain("E2E Import Agent");
|
||||
});
|
||||
|
||||
test("library happy path: user can open the imported or saved agent from Library in builder", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const { libraryPage, importedAgent } = await importAgentFromFile(
|
||||
page,
|
||||
TEST_AGENT_PATH,
|
||||
createUniqueAgentName("E2E Open Agent"),
|
||||
);
|
||||
|
||||
// Register the popup listener before clicking so we don't miss a fast open.
|
||||
// A short timeout covers the case where the link opens in the current tab.
|
||||
const popupPromise = page
|
||||
.context()
|
||||
.waitForEvent("page", { timeout: 10000 })
|
||||
.catch(() => null);
|
||||
await libraryPage.clickOpenInBuilder(importedAgent);
|
||||
const builderPage = (await popupPromise) ?? page;
|
||||
|
||||
await builderPage.waitForLoadState("domcontentloaded");
|
||||
await expect(builderPage).toHaveURL(/\/build/);
|
||||
const importedBuildPage = new BuildPage(builderPage);
|
||||
await importedBuildPage.waitForNodeOnCanvas();
|
||||
expect(await importedBuildPage.getNodeCount()).toBeGreaterThan(0);
|
||||
if (builderPage !== page) {
|
||||
await builderPage.close();
|
||||
}
|
||||
});
|
||||
|
||||
test("library happy path: user can start and stop a saved task from runner UI", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(180000);
|
||||
|
||||
const agentName = createUniqueAgentName("E2E Stop Task Agent");
|
||||
const { graphId } = await createLongRunningSavedAgent(page, agentName);
|
||||
|
||||
const libraryPage = new LibraryPage(page);
|
||||
await libraryPage.openSavedAgent(agentName);
|
||||
await clickRunButton(page);
|
||||
|
||||
await expect
|
||||
.poll(() => getActiveItemId(page), { timeout: 45000 })
|
||||
.not.toBe(null);
|
||||
const runId = getActiveItemId(page);
|
||||
expect(runId, "run id should be present after starting task").toBeTruthy();
|
||||
await expect
|
||||
.poll(() => libraryPage.getRunStatus(), { timeout: 45000 })
|
||||
.toBe("running");
|
||||
|
||||
const stopTaskButton = page.getByRole("button", { name: /Stop task/i });
|
||||
await expect(stopTaskButton).toBeVisible({ timeout: 30000 });
|
||||
const stopResponsePromise = page.waitForResponse(
|
||||
(response) =>
|
||||
response.request().method() === "POST" &&
|
||||
response
|
||||
.url()
|
||||
.includes(`/api/graphs/${graphId}/executions/${runId}/stop`),
|
||||
{ timeout: 15000 },
|
||||
);
|
||||
await stopTaskButton.click();
|
||||
const stopResponse = await stopResponsePromise;
|
||||
|
||||
expect(stopResponse.ok(), "stop run API should succeed").toBe(true);
|
||||
await expect(page.getByText("Run stopped")).toBeVisible({ timeout: 15000 });
|
||||
await expect
|
||||
.poll(
|
||||
async () => {
|
||||
const status = await getExecutionStatusFromApi(
|
||||
page,
|
||||
graphId,
|
||||
String(runId),
|
||||
);
|
||||
return STOPPED_RUN_STATUSES.has(status) ? status : "running";
|
||||
},
|
||||
{ timeout: 45000 },
|
||||
)
|
||||
.not.toBe("running");
|
||||
});
|
||||
|
||||
test("library happy path: user can run a saved agent and verify expected output", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(150000);
|
||||
|
||||
const agentName = createUniqueAgentName("E2E Expected Output Agent");
|
||||
const outputName = `e2e-result-${Date.now()}`;
|
||||
await createDeterministicCalculatorSavedAgent(page, agentName, outputName);
|
||||
|
||||
const libraryPage = new LibraryPage(page);
|
||||
await libraryPage.openSavedAgent(agentName);
|
||||
await clickRunButton(page);
|
||||
await libraryPage.waitForRunToComplete();
|
||||
await dismissFeedbackDialog(page);
|
||||
|
||||
await libraryPage.assertRunProducedOutput();
|
||||
await libraryPage.assertRunOutputValue(outputName, /^2(?:\.0+)?$/);
|
||||
await expect
|
||||
.poll(() => libraryPage.getRunStatus(), { timeout: 15000 })
|
||||
.toBe("completed");
|
||||
});
|
||||
|
||||
test("library happy path: user can edit a saved agent from Library and keep changes after refresh", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(150000);
|
||||
|
||||
const { agentName } = await createAndSaveDeterministicOutputAgent(
|
||||
page,
|
||||
"E2E Edit Persist Agent",
|
||||
);
|
||||
const editedValue = `edited-value-${Date.now()}`;
|
||||
|
||||
const libraryPage = new LibraryPage(page);
|
||||
await page.goto("/library");
|
||||
await libraryPage.waitForAgentsToLoad();
|
||||
await libraryPage.searchAgents(agentName);
|
||||
await libraryPage.waitForAgentsToLoad();
|
||||
|
||||
const agentCard = page
|
||||
.getByTestId("library-agent-card")
|
||||
.filter({ hasText: agentName })
|
||||
.first();
|
||||
await expect(agentCard).toBeVisible({ timeout: 15000 });
|
||||
|
||||
const popupPromise = page
|
||||
.context()
|
||||
.waitForEvent("page", { timeout: 10000 })
|
||||
.catch(() => null);
|
||||
await agentCard
|
||||
.getByTestId("library-agent-card-open-in-builder-link")
|
||||
.first()
|
||||
.click();
|
||||
const builderPage = (await popupPromise) ?? page;
|
||||
|
||||
const builderTabPage = new BuildPage(builderPage);
|
||||
await builderTabPage.waitForNodeOnCanvas();
|
||||
await builderTabPage.fillBlockInputByPlaceholder(
|
||||
"Enter string value...",
|
||||
editedValue,
|
||||
0,
|
||||
);
|
||||
|
||||
await builderPage.getByTestId("save-control-save-button").click();
|
||||
const saveAgentButton = builderPage.getByRole("button", {
|
||||
name: "Save Agent",
|
||||
});
|
||||
if (await saveAgentButton.isVisible({ timeout: 3000 }).catch(() => false)) {
|
||||
await expect(saveAgentButton).toBeEnabled({ timeout: 10000 });
|
||||
await saveAgentButton.click();
|
||||
await expect(saveAgentButton).toBeHidden({ timeout: 15000 });
|
||||
}
|
||||
|
||||
await builderPage.reload();
|
||||
await builderTabPage.waitForNodeOnCanvas();
|
||||
await expect(
|
||||
builderTabPage
|
||||
.getNodeLocator(0)
|
||||
.locator('input[placeholder="Enter string value..."]'),
|
||||
).toHaveValue(editedValue);
|
||||
|
||||
if (builderPage !== page) {
|
||||
await builderPage.close();
|
||||
}
|
||||
});
|
||||
|
||||
test("library happy path: user can rerun a completed task from the Library agent page", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const buildPage = new BuildPage(page);
|
||||
const { agentName } =
|
||||
await buildPage.createAndSaveSimpleAgent("E2E Rerun Agent");
|
||||
|
||||
const libraryPage = new LibraryPage(page);
|
||||
await libraryPage.openSavedAgent(agentName);
|
||||
await clickRunButton(page);
|
||||
await libraryPage.waitForRunToComplete();
|
||||
await dismissFeedbackDialog(page);
|
||||
|
||||
const rerunTaskButton = page.getByRole("button", { name: /Rerun task/i });
|
||||
await expect(rerunTaskButton).toBeVisible({ timeout: 45000 });
|
||||
|
||||
await expect
|
||||
.poll(() => getActiveItemId(page), { timeout: 45000 })
|
||||
.not.toBe(null);
|
||||
|
||||
const initialRunId = getActiveItemId(page);
|
||||
expect(initialRunId).toBeTruthy();
|
||||
|
||||
await rerunTaskButton.click();
|
||||
|
||||
await expect(page.getByText("Run started", { exact: true })).toBeVisible({
|
||||
timeout: 15000,
|
||||
});
|
||||
|
||||
await expect
|
||||
.poll(() => getActiveItemId(page), { timeout: 45000 })
|
||||
.not.toBe(initialRunId);
|
||||
|
||||
await libraryPage.waitForRunToComplete();
|
||||
|
||||
// Simple agent has no AgentOutputBlock — verify run completion only.
|
||||
const runStatus = await libraryPage.getRunStatus();
|
||||
expect(runStatus).toBe("completed");
|
||||
});
|
||||
|
||||
test("library happy path: user can delete a completed task from the run sidebar", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const buildPage = new BuildPage(page);
|
||||
const { agentName } = await buildPage.createAndSaveSimpleAgent(
|
||||
"E2E Delete Task Agent",
|
||||
);
|
||||
|
||||
const libraryPage = new LibraryPage(page);
|
||||
await libraryPage.openSavedAgent(agentName);
|
||||
await clickRunButton(page);
|
||||
await libraryPage.waitForRunToComplete();
|
||||
await dismissFeedbackDialog(page);
|
||||
|
||||
// Open the per-task actions dropdown ("More actions" three-dot button)
|
||||
// and use the menu's Delete task option to remove the run.
|
||||
const moreActionsButton = page
|
||||
.getByRole("button", { name: "More actions" })
|
||||
.first();
|
||||
await expect(moreActionsButton).toBeVisible({ timeout: 15000 });
|
||||
await moreActionsButton.click();
|
||||
|
||||
await page.getByRole("menuitem", { name: /Delete( this)? task/i }).click();
|
||||
|
||||
const confirmDialog = page.getByRole("dialog", { name: /Delete task/i });
|
||||
await expect(confirmDialog).toBeVisible({ timeout: 10000 });
|
||||
await confirmDialog.getByRole("button", { name: /^Delete Task$/ }).click();
|
||||
|
||||
// Toast confirms the backend actually deleted (not just dialog closed).
|
||||
await expect(page.getByText("Task deleted", { exact: true })).toBeVisible({
|
||||
timeout: 15000,
|
||||
});
|
||||
|
||||
// Sidebar should drop the only run, returning the page to initial
|
||||
// task-entry state.
|
||||
await expect(
|
||||
page.getByRole("button", { name: /^(Setup your task|New task)$/i }),
|
||||
).toBeVisible({ timeout: 15000 });
|
||||
});
|
||||
|
||||
test("library happy path: user can open the agent in builder from the exact runner customise-agent path", async ({
|
||||
page,
|
||||
context,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const buildPage = new BuildPage(page);
|
||||
const { agentName } = await buildPage.createAndSaveSimpleAgent(
|
||||
"E2E View Task Agent",
|
||||
);
|
||||
|
||||
const libraryPage = new LibraryPage(page);
|
||||
await libraryPage.openSavedAgent(agentName);
|
||||
await clickRunButton(page);
|
||||
await libraryPage.waitForRunToComplete();
|
||||
await dismissFeedbackDialog(page);
|
||||
|
||||
// The "View task details" eye-icon button on a completed run opens the
|
||||
// agent in the builder in a new tab. This exercises the runner → builder
|
||||
// navigation that QA item #22 ("Customise Agent" from Runner UI) covers.
|
||||
const selectedRunId = getActiveItemId(page);
|
||||
expect(selectedRunId).toBeTruthy();
|
||||
|
||||
const viewTaskButton = page
|
||||
.locator('[aria-label="View task details"]')
|
||||
.first();
|
||||
await expect(viewTaskButton).toBeVisible({ timeout: 15000 });
|
||||
const customiseAgentHref = await viewTaskButton.getAttribute("href");
|
||||
expect(customiseAgentHref).toContain("flowID=");
|
||||
expect(customiseAgentHref).toContain("flowVersion=");
|
||||
expect(customiseAgentHref).toContain(`flowExecutionID=${selectedRunId}`);
|
||||
|
||||
const popupPromise = context.waitForEvent("page", { timeout: 15000 });
|
||||
await viewTaskButton.click();
|
||||
const builderTab = await popupPromise;
|
||||
|
||||
await builderTab.waitForLoadState("domcontentloaded");
|
||||
await expect(builderTab).toHaveURL(/\/build/);
|
||||
await expect(builderTab).toHaveURL(
|
||||
new RegExp(`flowExecutionID=${selectedRunId}`),
|
||||
);
|
||||
|
||||
// Verify the builder canvas actually rendered with the agent's nodes —
|
||||
// a navigation that lands on /build but never paints the graph would
|
||||
// otherwise pass on URL alone.
|
||||
const builderTabPage = new BuildPage(builderTab);
|
||||
await builderTabPage.waitForNodeOnCanvas();
|
||||
expect(await builderTabPage.getNodeCount()).toBeGreaterThan(0);
|
||||
|
||||
await builderTab.close();
|
||||
});
|
||||
@@ -0,0 +1,48 @@
|
||||
import { expect, test } from "./coverage-fixture";
|
||||
import { E2E_AUTH_STATES } from "./credentials/accounts";
|
||||
import {
|
||||
clickRunButton,
|
||||
dismissFeedbackDialog,
|
||||
LibraryPage,
|
||||
} from "./pages/library.page";
|
||||
import { MarketplacePage } from "./pages/marketplace.page";
|
||||
|
||||
test.use({ storageState: E2E_AUTH_STATES.marketplace });
|
||||
|
||||
test("marketplace happy path: user can browse Marketplace and open an agent detail page", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(90000);
|
||||
|
||||
const marketplacePage = new MarketplacePage(page);
|
||||
await marketplacePage.openFeaturedAgent();
|
||||
|
||||
await expect(page.getByTestId("agent-description")).toBeVisible();
|
||||
});
|
||||
|
||||
test("marketplace happy path: user can add a Marketplace agent to Library and run it", async ({
|
||||
page,
|
||||
}) => {
|
||||
test.setTimeout(120000);
|
||||
|
||||
const marketplacePage = new MarketplacePage(page);
|
||||
await marketplacePage.openRunnableAgent();
|
||||
|
||||
const agentName = await page.getByTestId("agent-title").innerText();
|
||||
|
||||
await page.getByTestId("agent-add-library-button").click();
|
||||
await expect(page.getByText("Redirecting to your library...")).toBeVisible();
|
||||
await expect(page).toHaveURL(/\/library\/agents\//);
|
||||
|
||||
const libraryPage = new LibraryPage(page);
|
||||
await libraryPage.openSavedAgent(agentName);
|
||||
await clickRunButton(page);
|
||||
|
||||
await libraryPage.waitForRunToComplete();
|
||||
await dismissFeedbackDialog(page);
|
||||
|
||||
const runStatus = await libraryPage.getRunStatus();
|
||||
expect(runStatus).toBe("completed");
|
||||
await libraryPage.assertRunProducedOutput();
|
||||
await libraryPage.assertFirstRunOutputValue(/^\d+(?:\.0+)?$/);
|
||||
});
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user