fix(platform-cost): address PR review — deduplicate filter logic, skip redundant query, improve frontend

Backend:
- Extract _build_raw_where() helper so raw SQL and Prisma WHERE share
  filter logic (review item #4 — duplicated filter logic)
- Skip redundant total_agg_no_tracking_type_groups query when
  tracking_type is None since it duplicates total_agg_groups (item #3)
- Convert CostBucket from TypedDict to BaseModel for consistency (nit #1)
- Replace fragile 8-way positional tuple unpack with indexed list access

Frontend:
- Make 12 SummaryCards data-driven via a cards config array (item #5)
- Use friendlier percentile labels: Typical/Upper/High/Peak Cost (P50/P75/P95/P99)
- Update test fixtures with all new dashboard fields (item #1)
- Add test assertions for new summary card labels, cost buckets, token
  values, and user table columns
This commit is contained in:
majdyz
2026-04-13 05:16:55 +00:00
parent c51471a9df
commit 4cc8ef4409
3 changed files with 302 additions and 144 deletions

View File

@@ -7,7 +7,6 @@ from prisma.models import PlatformCostLog as PrismaLog
from prisma.models import User as PrismaUser
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
from pydantic import BaseModel
from typing_extensions import TypedDict
from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
@@ -164,7 +163,7 @@ class CostLogRow(BaseModel):
cache_creation_tokens: int | None = None
class CostBucket(TypedDict):
class CostBucket(BaseModel):
    """One histogram bucket of the per-request cost distribution.

    Returned by the dashboard's bucket aggregation query and rendered by the
    frontend "Cost Distribution by Bucket" section.
    """

    # Human-readable dollar-range label for the bucket, e.g. "$0-0.50".
    bucket: str
    # Number of requests whose cost fell inside this bucket's range.
    count: int
@@ -244,6 +243,66 @@ def _build_prisma_where(
return where
def _build_raw_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
) -> tuple[str, list]:
"""Build a parameterised WHERE clause for raw SQL queries.
Mirrors the filter logic of ``_build_prisma_where`` so there is a single
source of truth for which columns are filtered and how. The first clause
always restricts to ``cost_usd`` tracking type unless *tracking_type* is
explicitly provided by the caller.
"""
params: list = []
clauses: list[str] = []
idx = 1
# Always filter by tracking type — defaults to cost_usd for percentile /
# bucket queries that only make sense on cost-denominated rows.
tt = tracking_type if tracking_type is not None else "cost_usd"
clauses.append(f'"trackingType" = ${idx}')
params.append(tt)
idx += 1
if start is not None:
clauses.append(f'"createdAt" >= ${idx}')
params.append(start)
idx += 1
if end is not None:
clauses.append(f'"createdAt" <= ${idx}')
params.append(end)
idx += 1
if provider is not None:
clauses.append(f'"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id is not None:
clauses.append(f'"userId" = ${idx}')
params.append(user_id)
idx += 1
if model is not None:
clauses.append(f'"model" = ${idx}')
params.append(model)
idx += 1
if block_name is not None:
clauses.append(f'LOWER("blockName") = LOWER(${idx})')
params.append(block_name)
idx += 1
return (" AND ".join(clauses), params)
@cached(ttl_seconds=30)
async def get_platform_cost_dashboard(
start: datetime | None = None,
@@ -291,59 +350,14 @@ async def get_platform_cost_dashboard(
}
# Build parameterised WHERE clause for the raw SQL percentile/bucket
# queries so they honour all active dashboard filters, not just start date.
raw_params: list = [start]
raw_where_clauses = [
"\"trackingType\" = 'cost_usd'",
'"createdAt" >= $1',
]
param_idx = 2 # $1 is already start
# queries. Uses _build_raw_where so filter logic is shared with
# _build_prisma_where and only maintained in one place.
raw_where, raw_params = _build_raw_where(
start, end, provider, user_id, model, block_name, tracking_type
)
if end is not None:
raw_where_clauses.append(f'"createdAt" <= ${param_idx}')
raw_params.append(end)
param_idx += 1
if provider is not None:
raw_where_clauses.append(f'"provider" = ${param_idx}')
raw_params.append(provider.lower())
param_idx += 1
if user_id is not None:
raw_where_clauses.append(f'"userId" = ${param_idx}')
raw_params.append(user_id)
param_idx += 1
if model is not None:
raw_where_clauses.append(f'"model" = ${param_idx}')
raw_params.append(model)
param_idx += 1
if block_name is not None:
raw_where_clauses.append(f'LOWER("blockName") = LOWER(${param_idx})')
raw_params.append(block_name)
param_idx += 1
# If the caller supplied a specific tracking_type filter, replace the
# hardcoded cost_usd clause so the percentile/bucket queries respect it.
if tracking_type is not None:
raw_where_clauses[0] = f'"trackingType" = ${param_idx}'
raw_params.append(tracking_type)
param_idx += 1
raw_where = " AND ".join(raw_where_clauses)
# Run all eight aggregation queries in parallel.
(
by_provider_groups,
by_user_groups,
by_user_tracking_groups,
total_user_groups,
total_agg_groups,
total_agg_no_tracking_type_groups,
percentile_rows,
bucket_rows,
) = await asyncio.gather(
# Queries that always run regardless of tracking_type filter.
common_queries = [
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
# sort by total cost descending in Python after fetch.
PrismaLog.prisma().group_by(
@@ -386,20 +400,6 @@ async def get_platform_cost_dashboard(
},
count=True,
),
# Total aggregate (no tracking_type filter): used to compute
# cost_bearing_requests and token_bearing_requests denominators so
# global avg stats remain meaningful when the caller filters the main
# view by a specific tracking_type (e.g. 'tokens').
PrismaLog.prisma().group_by(
by=["provider", "trackingType"],
where=where_no_tracking_type,
sum={
"costMicrodollars": True,
"inputTokens": True,
"outputTokens": True,
},
count=True,
),
# Percentile distribution of cost per request (respects all filters).
query_raw_with_schema(
"SELECT"
@@ -440,6 +440,43 @@ async def get_platform_cost_dashboard(
' ORDER BY MIN("costMicrodollars")',
*raw_params,
),
]
# Only run the unfiltered aggregate query when tracking_type is set;
# when tracking_type is None, the filtered query already contains all
# tracking types and reusing it avoids a redundant full aggregation.
if tracking_type is not None:
common_queries.append(
# Total aggregate (no tracking_type filter): used to compute
# cost_bearing_requests and token_bearing_requests denominators so
# global avg stats remain meaningful when the caller filters the
# main view by a specific tracking_type (e.g. 'tokens').
PrismaLog.prisma().group_by(
by=["provider", "trackingType"],
where=where_no_tracking_type,
sum={
"costMicrodollars": True,
"inputTokens": True,
"outputTokens": True,
},
count=True,
)
)
results = await asyncio.gather(*common_queries)
# Unpack results by name for clarity.
by_provider_groups = results[0]
by_user_groups = results[1]
by_user_tracking_groups = results[2]
total_user_groups = results[3]
total_agg_groups = results[4]
percentile_rows = results[5]
bucket_rows = results[6]
# When tracking_type is None, the filtered and unfiltered queries are
# identical — reuse total_agg_groups to avoid the extra DB round-trip.
total_agg_no_tracking_type_groups = (
results[7] if tracking_type is not None else total_agg_groups
)
# Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.

View File

@@ -29,6 +29,16 @@ const emptyDashboard: PlatformCostDashboard = {
total_cost_microdollars: 0,
total_requests: 0,
total_users: 0,
total_input_tokens: 0,
total_output_tokens: 0,
avg_input_tokens_per_request: 0,
avg_output_tokens_per_request: 0,
avg_cost_microdollars_per_request: 0,
cost_p50_microdollars: 0,
cost_p75_microdollars: 0,
cost_p95_microdollars: 0,
cost_p99_microdollars: 0,
cost_buckets: [],
by_provider: [],
by_user: [],
};
@@ -47,6 +57,20 @@ const dashboardWithData: PlatformCostDashboard = {
total_cost_microdollars: 5_000_000,
total_requests: 100,
total_users: 5,
total_input_tokens: 150000,
total_output_tokens: 60000,
avg_input_tokens_per_request: 2500,
avg_output_tokens_per_request: 1000,
avg_cost_microdollars_per_request: 83333,
cost_p50_microdollars: 50000,
cost_p75_microdollars: 100000,
cost_p95_microdollars: 250000,
cost_p99_microdollars: 500000,
cost_buckets: [
{ bucket: "$0-0.50", count: 80 },
{ bucket: "$0.50-1", count: 15 },
{ bucket: "$1-2", count: 5 },
],
by_provider: [
{
provider: "openai",
@@ -75,6 +99,7 @@ const dashboardWithData: PlatformCostDashboard = {
total_input_tokens: 50000,
total_output_tokens: 20000,
request_count: 60,
cost_bearing_request_count: 40,
},
],
};
@@ -138,7 +163,8 @@ describe("PlatformCostContent", () => {
// "Known Cost" appears in both the SummaryCard and the ProviderTable header
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
// All cost summary cards (Known Cost, Estimated Total, Avg Cost, P50/P75/P95/P99) show $0.0000
// All cost summary cards (Known Cost, Estimated Total, Avg Cost,
// Typical/Upper/High/Peak Cost) show $0.0000
const zeroCostItems = screen.getAllByText("$0.0000");
expect(zeroCostItems.length).toBe(7);
expect(screen.getByText("No cost data yet")).toBeDefined();
@@ -227,10 +253,83 @@ describe("PlatformCostContent", () => {
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Original 4 cards
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
expect(screen.getByText("Estimated Total")).toBeDefined();
expect(screen.getByText("Total Requests")).toBeDefined();
expect(screen.getByText("Active Users")).toBeDefined();
// New average/token cards
expect(screen.getByText("Avg Cost / Request")).toBeDefined();
expect(screen.getByText("Avg Input Tokens")).toBeDefined();
expect(screen.getByText("Avg Output Tokens")).toBeDefined();
expect(screen.getByText("Total Tokens")).toBeDefined();
// Percentile cards (friendlier labels)
expect(screen.getByText("Typical Cost (P50)")).toBeDefined();
expect(screen.getByText("Upper Cost (P75)")).toBeDefined();
expect(screen.getByText("High Cost (P95)")).toBeDefined();
expect(screen.getByText("Peak Cost (P99)")).toBeDefined();
});
it("renders cost distribution buckets", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
expect(screen.getByText("Cost Distribution by Bucket")).toBeDefined();
expect(screen.getByText("$0-0.50")).toBeDefined();
expect(screen.getByText("$0.50-1")).toBeDefined();
expect(screen.getByText("$1-2")).toBeDefined();
expect(screen.getByText("80")).toBeDefined();
expect(screen.getByText("15")).toBeDefined();
});
it("renders new summary card values from fixture data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent();
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// Avg Input Tokens: 2500 formatted
expect(screen.getByText("2,500")).toBeDefined();
// Avg Output Tokens: 1000 formatted
expect(screen.getByText("1,000")).toBeDefined();
// P50 cost: 50000 microdollars = $0.0500
expect(screen.getByText("$0.0500")).toBeDefined();
});
it("renders user table avg cost column with fixture data", async () => {
mockUseGetDashboard.mockReturnValue({
data: dashboardWithData,
isLoading: false,
});
mockUseGetLogs.mockReturnValue({
data: logsWithData,
isLoading: false,
});
renderComponent({ tab: "by-user" });
await waitFor(() =>
expect(document.querySelector(".animate-pulse")).toBeNull(),
);
// User table should show Avg Cost / Req header
expect(screen.getByText("Avg Cost / Req")).toBeDefined();
// Input/Output token columns
expect(screen.getByText("Input Tokens")).toBeDefined();
expect(screen.getByText("Output Tokens")).toBeDefined();
});
it("renders filter inputs", async () => {

View File

@@ -206,7 +206,8 @@ export function PlatformCostContent({ searchParams }: Props) {
{loading ? (
<div className="flex flex-col gap-4">
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
{[...Array(12)].map((_, i) => (
{/* 12 skeleton placeholders — one per summary card */}
{Array.from({ length: 12 }, (_, i) => (
<Skeleton key={i} className="h-20 rounded-lg" />
))}
</div>
@@ -218,80 +219,101 @@ export function PlatformCostContent({ searchParams }: Props) {
<>
{dashboard && (
<>
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
<SummaryCard
label="Known Cost"
value={formatMicrodollars(dashboard.total_cost_microdollars)}
subtitle="From providers that report USD cost"
/>
<SummaryCard
label="Estimated Total"
value={formatMicrodollars(totalEstimatedCost)}
subtitle="Including per-run cost estimates"
/>
<SummaryCard
label="Total Requests"
value={dashboard.total_requests.toLocaleString()}
/>
<SummaryCard
label="Active Users"
value={dashboard.total_users.toLocaleString()}
/>
<SummaryCard
label="Avg Cost / Request"
value={formatMicrodollars(
dashboard.avg_cost_microdollars_per_request ?? 0,
)}
subtitle="Known cost divided by cost-bearing requests"
/>
<SummaryCard
label="Avg Input Tokens"
value={Math.round(
dashboard.avg_input_tokens_per_request ?? 0,
).toLocaleString()}
subtitle="Prompt tokens per request (context size)"
/>
<SummaryCard
label="Avg Output Tokens"
value={Math.round(
dashboard.avg_output_tokens_per_request ?? 0,
).toLocaleString()}
subtitle="Completion tokens per request (response length)"
/>
<SummaryCard
label="Total Tokens"
value={`${formatTokens(dashboard.total_input_tokens ?? 0)} in / ${formatTokens(dashboard.total_output_tokens ?? 0)} out`}
subtitle="Prompt vs completion token split"
/>
<SummaryCard
label="P50 Cost / Request"
value={formatMicrodollars(
dashboard.cost_p50_microdollars ?? 0,
)}
subtitle="Median cost per request"
/>
<SummaryCard
label="P75 Cost / Request"
value={formatMicrodollars(
dashboard.cost_p75_microdollars ?? 0,
)}
subtitle="75th percentile cost"
/>
<SummaryCard
label="P95 Cost / Request"
value={formatMicrodollars(
dashboard.cost_p95_microdollars ?? 0,
)}
subtitle="95th percentile cost"
/>
<SummaryCard
label="P99 Cost / Request"
value={formatMicrodollars(
dashboard.cost_p99_microdollars ?? 0,
)}
subtitle="99th percentile cost"
/>
</div>
{(() => {
const summaryCards: {
label: string;
value: string;
subtitle?: string;
}[] = [
{
label: "Known Cost",
value: formatMicrodollars(
dashboard.total_cost_microdollars,
),
subtitle: "From providers that report USD cost",
},
{
label: "Estimated Total",
value: formatMicrodollars(totalEstimatedCost),
subtitle: "Including per-run cost estimates",
},
{
label: "Total Requests",
value: dashboard.total_requests.toLocaleString(),
},
{
label: "Active Users",
value: dashboard.total_users.toLocaleString(),
},
{
label: "Avg Cost / Request",
value: formatMicrodollars(
dashboard.avg_cost_microdollars_per_request ?? 0,
),
subtitle: "Known cost divided by cost-bearing requests",
},
{
label: "Avg Input Tokens",
value: Math.round(
dashboard.avg_input_tokens_per_request ?? 0,
).toLocaleString(),
subtitle: "Prompt tokens per request (context size)",
},
{
label: "Avg Output Tokens",
value: Math.round(
dashboard.avg_output_tokens_per_request ?? 0,
).toLocaleString(),
subtitle:
"Completion tokens per request (response length)",
},
{
label: "Total Tokens",
value: `${formatTokens(dashboard.total_input_tokens ?? 0)} in / ${formatTokens(dashboard.total_output_tokens ?? 0)} out`,
subtitle: "Prompt vs completion token split",
},
{
label: "Typical Cost (P50)",
value: formatMicrodollars(
dashboard.cost_p50_microdollars ?? 0,
),
subtitle: "Median cost per request",
},
{
label: "Upper Cost (P75)",
value: formatMicrodollars(
dashboard.cost_p75_microdollars ?? 0,
),
subtitle: "75th percentile cost",
},
{
label: "High Cost (P95)",
value: formatMicrodollars(
dashboard.cost_p95_microdollars ?? 0,
),
subtitle: "95th percentile cost",
},
{
label: "Peak Cost (P99)",
value: formatMicrodollars(
dashboard.cost_p99_microdollars ?? 0,
),
subtitle: "99th percentile cost",
},
];
return (
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
{summaryCards.map((card) => (
<SummaryCard
key={card.label}
label={card.label}
value={card.value}
subtitle={card.subtitle}
/>
))}
</div>
);
})()}
{dashboard.cost_buckets && dashboard.cost_buckets.length > 0 && (
<div className="rounded-lg border p-4">