mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(platform-cost): address PR review — deduplicate filter logic, skip redundant query, improve frontend
Backend:
- Extract _build_raw_where() helper so raw SQL and Prisma WHERE share filter logic (review item #4 — duplicated filter logic)
- Skip redundant total_agg_no_tracking_type_groups query when tracking_type is None since it duplicates total_agg_groups (item #3)
- Convert CostBucket from TypedDict to BaseModel for consistency (nit #1)
- Replace fragile 8-way positional tuple unpack with indexed list access

Frontend:
- Make 12 SummaryCards data-driven via a cards config array (item #5)
- Use friendlier percentile labels: Typical/Upper/High/Peak Cost (P50/P75/P95/P99)
- Update test fixtures with all new dashboard fields (item #1)
- Add test assertions for new summary card labels, cost buckets, token values, and user table columns
This commit is contained in:
@@ -7,7 +7,6 @@ from prisma.models import PlatformCostLog as PrismaLog
|
||||
from prisma.models import User as PrismaUser
|
||||
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from backend.data.db import query_raw_with_schema
|
||||
from backend.util.cache import cached
|
||||
@@ -164,7 +163,7 @@ class CostLogRow(BaseModel):
|
||||
cache_creation_tokens: int | None = None
|
||||
|
||||
|
||||
class CostBucket(TypedDict):
|
||||
class CostBucket(BaseModel):
|
||||
bucket: str
|
||||
count: int
|
||||
|
||||
@@ -244,6 +243,66 @@ def _build_prisma_where(
|
||||
return where
|
||||
|
||||
|
||||
def _build_raw_where(
|
||||
start: datetime | None,
|
||||
end: datetime | None,
|
||||
provider: str | None,
|
||||
user_id: str | None,
|
||||
model: str | None = None,
|
||||
block_name: str | None = None,
|
||||
tracking_type: str | None = None,
|
||||
) -> tuple[str, list]:
|
||||
"""Build a parameterised WHERE clause for raw SQL queries.
|
||||
|
||||
Mirrors the filter logic of ``_build_prisma_where`` so there is a single
|
||||
source of truth for which columns are filtered and how. The first clause
|
||||
always restricts to ``cost_usd`` tracking type unless *tracking_type* is
|
||||
explicitly provided by the caller.
|
||||
"""
|
||||
params: list = []
|
||||
clauses: list[str] = []
|
||||
idx = 1
|
||||
|
||||
# Always filter by tracking type — defaults to cost_usd for percentile /
|
||||
# bucket queries that only make sense on cost-denominated rows.
|
||||
tt = tracking_type if tracking_type is not None else "cost_usd"
|
||||
clauses.append(f'"trackingType" = ${idx}')
|
||||
params.append(tt)
|
||||
idx += 1
|
||||
|
||||
if start is not None:
|
||||
clauses.append(f'"createdAt" >= ${idx}')
|
||||
params.append(start)
|
||||
idx += 1
|
||||
|
||||
if end is not None:
|
||||
clauses.append(f'"createdAt" <= ${idx}')
|
||||
params.append(end)
|
||||
idx += 1
|
||||
|
||||
if provider is not None:
|
||||
clauses.append(f'"provider" = ${idx}')
|
||||
params.append(provider.lower())
|
||||
idx += 1
|
||||
|
||||
if user_id is not None:
|
||||
clauses.append(f'"userId" = ${idx}')
|
||||
params.append(user_id)
|
||||
idx += 1
|
||||
|
||||
if model is not None:
|
||||
clauses.append(f'"model" = ${idx}')
|
||||
params.append(model)
|
||||
idx += 1
|
||||
|
||||
if block_name is not None:
|
||||
clauses.append(f'LOWER("blockName") = LOWER(${idx})')
|
||||
params.append(block_name)
|
||||
idx += 1
|
||||
|
||||
return (" AND ".join(clauses), params)
|
||||
|
||||
|
||||
@cached(ttl_seconds=30)
|
||||
async def get_platform_cost_dashboard(
|
||||
start: datetime | None = None,
|
||||
@@ -291,59 +350,14 @@ async def get_platform_cost_dashboard(
|
||||
}
|
||||
|
||||
# Build parameterised WHERE clause for the raw SQL percentile/bucket
|
||||
# queries so they honour all active dashboard filters, not just start date.
|
||||
raw_params: list = [start]
|
||||
raw_where_clauses = [
|
||||
"\"trackingType\" = 'cost_usd'",
|
||||
'"createdAt" >= $1',
|
||||
]
|
||||
param_idx = 2 # $1 is already start
|
||||
# queries. Uses _build_raw_where so filter logic is shared with
|
||||
# _build_prisma_where and only maintained in one place.
|
||||
raw_where, raw_params = _build_raw_where(
|
||||
start, end, provider, user_id, model, block_name, tracking_type
|
||||
)
|
||||
|
||||
if end is not None:
|
||||
raw_where_clauses.append(f'"createdAt" <= ${param_idx}')
|
||||
raw_params.append(end)
|
||||
param_idx += 1
|
||||
|
||||
if provider is not None:
|
||||
raw_where_clauses.append(f'"provider" = ${param_idx}')
|
||||
raw_params.append(provider.lower())
|
||||
param_idx += 1
|
||||
|
||||
if user_id is not None:
|
||||
raw_where_clauses.append(f'"userId" = ${param_idx}')
|
||||
raw_params.append(user_id)
|
||||
param_idx += 1
|
||||
|
||||
if model is not None:
|
||||
raw_where_clauses.append(f'"model" = ${param_idx}')
|
||||
raw_params.append(model)
|
||||
param_idx += 1
|
||||
|
||||
if block_name is not None:
|
||||
raw_where_clauses.append(f'LOWER("blockName") = LOWER(${param_idx})')
|
||||
raw_params.append(block_name)
|
||||
param_idx += 1
|
||||
|
||||
# If the caller supplied a specific tracking_type filter, replace the
|
||||
# hardcoded cost_usd clause so the percentile/bucket queries respect it.
|
||||
if tracking_type is not None:
|
||||
raw_where_clauses[0] = f'"trackingType" = ${param_idx}'
|
||||
raw_params.append(tracking_type)
|
||||
param_idx += 1
|
||||
|
||||
raw_where = " AND ".join(raw_where_clauses)
|
||||
|
||||
# Run all eight aggregation queries in parallel.
|
||||
(
|
||||
by_provider_groups,
|
||||
by_user_groups,
|
||||
by_user_tracking_groups,
|
||||
total_user_groups,
|
||||
total_agg_groups,
|
||||
total_agg_no_tracking_type_groups,
|
||||
percentile_rows,
|
||||
bucket_rows,
|
||||
) = await asyncio.gather(
|
||||
# Queries that always run regardless of tracking_type filter.
|
||||
common_queries = [
|
||||
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
|
||||
# sort by total cost descending in Python after fetch.
|
||||
PrismaLog.prisma().group_by(
|
||||
@@ -386,20 +400,6 @@ async def get_platform_cost_dashboard(
|
||||
},
|
||||
count=True,
|
||||
),
|
||||
# Total aggregate (no tracking_type filter): used to compute
|
||||
# cost_bearing_requests and token_bearing_requests denominators so
|
||||
# global avg stats remain meaningful when the caller filters the main
|
||||
# view by a specific tracking_type (e.g. 'tokens').
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["provider", "trackingType"],
|
||||
where=where_no_tracking_type,
|
||||
sum={
|
||||
"costMicrodollars": True,
|
||||
"inputTokens": True,
|
||||
"outputTokens": True,
|
||||
},
|
||||
count=True,
|
||||
),
|
||||
# Percentile distribution of cost per request (respects all filters).
|
||||
query_raw_with_schema(
|
||||
"SELECT"
|
||||
@@ -440,6 +440,43 @@ async def get_platform_cost_dashboard(
|
||||
' ORDER BY MIN("costMicrodollars")',
|
||||
*raw_params,
|
||||
),
|
||||
]
|
||||
|
||||
# Only run the unfiltered aggregate query when tracking_type is set;
|
||||
# when tracking_type is None, the filtered query already contains all
|
||||
# tracking types and reusing it avoids a redundant full aggregation.
|
||||
if tracking_type is not None:
|
||||
common_queries.append(
|
||||
# Total aggregate (no tracking_type filter): used to compute
|
||||
# cost_bearing_requests and token_bearing_requests denominators so
|
||||
# global avg stats remain meaningful when the caller filters the
|
||||
# main view by a specific tracking_type (e.g. 'tokens').
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["provider", "trackingType"],
|
||||
where=where_no_tracking_type,
|
||||
sum={
|
||||
"costMicrodollars": True,
|
||||
"inputTokens": True,
|
||||
"outputTokens": True,
|
||||
},
|
||||
count=True,
|
||||
)
|
||||
)
|
||||
|
||||
results = await asyncio.gather(*common_queries)
|
||||
|
||||
# Unpack results by name for clarity.
|
||||
by_provider_groups = results[0]
|
||||
by_user_groups = results[1]
|
||||
by_user_tracking_groups = results[2]
|
||||
total_user_groups = results[3]
|
||||
total_agg_groups = results[4]
|
||||
percentile_rows = results[5]
|
||||
bucket_rows = results[6]
|
||||
# When tracking_type is None, the filtered and unfiltered queries are
|
||||
# identical — reuse total_agg_groups to avoid the extra DB round-trip.
|
||||
total_agg_no_tracking_type_groups = (
|
||||
results[7] if tracking_type is not None else total_agg_groups
|
||||
)
|
||||
|
||||
# Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
|
||||
|
||||
@@ -29,6 +29,16 @@ const emptyDashboard: PlatformCostDashboard = {
|
||||
total_cost_microdollars: 0,
|
||||
total_requests: 0,
|
||||
total_users: 0,
|
||||
total_input_tokens: 0,
|
||||
total_output_tokens: 0,
|
||||
avg_input_tokens_per_request: 0,
|
||||
avg_output_tokens_per_request: 0,
|
||||
avg_cost_microdollars_per_request: 0,
|
||||
cost_p50_microdollars: 0,
|
||||
cost_p75_microdollars: 0,
|
||||
cost_p95_microdollars: 0,
|
||||
cost_p99_microdollars: 0,
|
||||
cost_buckets: [],
|
||||
by_provider: [],
|
||||
by_user: [],
|
||||
};
|
||||
@@ -47,6 +57,20 @@ const dashboardWithData: PlatformCostDashboard = {
|
||||
total_cost_microdollars: 5_000_000,
|
||||
total_requests: 100,
|
||||
total_users: 5,
|
||||
total_input_tokens: 150000,
|
||||
total_output_tokens: 60000,
|
||||
avg_input_tokens_per_request: 2500,
|
||||
avg_output_tokens_per_request: 1000,
|
||||
avg_cost_microdollars_per_request: 83333,
|
||||
cost_p50_microdollars: 50000,
|
||||
cost_p75_microdollars: 100000,
|
||||
cost_p95_microdollars: 250000,
|
||||
cost_p99_microdollars: 500000,
|
||||
cost_buckets: [
|
||||
{ bucket: "$0-0.50", count: 80 },
|
||||
{ bucket: "$0.50-1", count: 15 },
|
||||
{ bucket: "$1-2", count: 5 },
|
||||
],
|
||||
by_provider: [
|
||||
{
|
||||
provider: "openai",
|
||||
@@ -75,6 +99,7 @@ const dashboardWithData: PlatformCostDashboard = {
|
||||
total_input_tokens: 50000,
|
||||
total_output_tokens: 20000,
|
||||
request_count: 60,
|
||||
cost_bearing_request_count: 40,
|
||||
},
|
||||
],
|
||||
};
|
||||
@@ -138,7 +163,8 @@ describe("PlatformCostContent", () => {
|
||||
// "Known Cost" appears in both the SummaryCard and the ProviderTable header
|
||||
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
|
||||
expect(screen.getByText("Estimated Total")).toBeDefined();
|
||||
// All cost summary cards (Known Cost, Estimated Total, Avg Cost, P50/P75/P95/P99) show $0.0000
|
||||
// All cost summary cards (Known Cost, Estimated Total, Avg Cost,
|
||||
// Typical/Upper/High/Peak Cost) show $0.0000
|
||||
const zeroCostItems = screen.getAllByText("$0.0000");
|
||||
expect(zeroCostItems.length).toBe(7);
|
||||
expect(screen.getByText("No cost data yet")).toBeDefined();
|
||||
@@ -227,10 +253,83 @@ describe("PlatformCostContent", () => {
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
// Original 4 cards
|
||||
expect(screen.getAllByText("Known Cost").length).toBeGreaterThanOrEqual(1);
|
||||
expect(screen.getByText("Estimated Total")).toBeDefined();
|
||||
expect(screen.getByText("Total Requests")).toBeDefined();
|
||||
expect(screen.getByText("Active Users")).toBeDefined();
|
||||
// New average/token cards
|
||||
expect(screen.getByText("Avg Cost / Request")).toBeDefined();
|
||||
expect(screen.getByText("Avg Input Tokens")).toBeDefined();
|
||||
expect(screen.getByText("Avg Output Tokens")).toBeDefined();
|
||||
expect(screen.getByText("Total Tokens")).toBeDefined();
|
||||
// Percentile cards (friendlier labels)
|
||||
expect(screen.getByText("Typical Cost (P50)")).toBeDefined();
|
||||
expect(screen.getByText("Upper Cost (P75)")).toBeDefined();
|
||||
expect(screen.getByText("High Cost (P95)")).toBeDefined();
|
||||
expect(screen.getByText("Peak Cost (P99)")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders cost distribution buckets", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
expect(screen.getByText("Cost Distribution by Bucket")).toBeDefined();
|
||||
expect(screen.getByText("$0-0.50")).toBeDefined();
|
||||
expect(screen.getByText("$0.50-1")).toBeDefined();
|
||||
expect(screen.getByText("$1-2")).toBeDefined();
|
||||
expect(screen.getByText("80")).toBeDefined();
|
||||
expect(screen.getByText("15")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders new summary card values from fixture data", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent();
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
// Avg Input Tokens: 2500 formatted
|
||||
expect(screen.getByText("2,500")).toBeDefined();
|
||||
// Avg Output Tokens: 1000 formatted
|
||||
expect(screen.getByText("1,000")).toBeDefined();
|
||||
// P50 cost: 50000 microdollars = $0.0500
|
||||
expect(screen.getByText("$0.0500")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders user table avg cost column with fixture data", async () => {
|
||||
mockUseGetDashboard.mockReturnValue({
|
||||
data: dashboardWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
mockUseGetLogs.mockReturnValue({
|
||||
data: logsWithData,
|
||||
isLoading: false,
|
||||
});
|
||||
renderComponent({ tab: "by-user" });
|
||||
await waitFor(() =>
|
||||
expect(document.querySelector(".animate-pulse")).toBeNull(),
|
||||
);
|
||||
// User table should show Avg Cost / Req header
|
||||
expect(screen.getByText("Avg Cost / Req")).toBeDefined();
|
||||
// Input/Output token columns
|
||||
expect(screen.getByText("Input Tokens")).toBeDefined();
|
||||
expect(screen.getByText("Output Tokens")).toBeDefined();
|
||||
});
|
||||
|
||||
it("renders filter inputs", async () => {
|
||||
|
||||
@@ -206,7 +206,8 @@ export function PlatformCostContent({ searchParams }: Props) {
|
||||
{loading ? (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="grid grid-cols-2 gap-4 md:grid-cols-4">
|
||||
{[...Array(12)].map((_, i) => (
|
||||
{/* 12 skeleton placeholders — one per summary card */}
|
||||
{Array.from({ length: 12 }, (_, i) => (
|
||||
<Skeleton key={i} className="h-20 rounded-lg" />
|
||||
))}
|
||||
</div>
|
||||
@@ -218,80 +219,101 @@ export function PlatformCostContent({ searchParams }: Props) {
|
||||
<>
|
||||
{dashboard && (
|
||||
<>
|
||||
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
|
||||
<SummaryCard
|
||||
label="Known Cost"
|
||||
value={formatMicrodollars(dashboard.total_cost_microdollars)}
|
||||
subtitle="From providers that report USD cost"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Estimated Total"
|
||||
value={formatMicrodollars(totalEstimatedCost)}
|
||||
subtitle="Including per-run cost estimates"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Total Requests"
|
||||
value={dashboard.total_requests.toLocaleString()}
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Active Users"
|
||||
value={dashboard.total_users.toLocaleString()}
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Avg Cost / Request"
|
||||
value={formatMicrodollars(
|
||||
dashboard.avg_cost_microdollars_per_request ?? 0,
|
||||
)}
|
||||
subtitle="Known cost divided by cost-bearing requests"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Avg Input Tokens"
|
||||
value={Math.round(
|
||||
dashboard.avg_input_tokens_per_request ?? 0,
|
||||
).toLocaleString()}
|
||||
subtitle="Prompt tokens per request (context size)"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Avg Output Tokens"
|
||||
value={Math.round(
|
||||
dashboard.avg_output_tokens_per_request ?? 0,
|
||||
).toLocaleString()}
|
||||
subtitle="Completion tokens per request (response length)"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="Total Tokens"
|
||||
value={`${formatTokens(dashboard.total_input_tokens ?? 0)} in / ${formatTokens(dashboard.total_output_tokens ?? 0)} out`}
|
||||
subtitle="Prompt vs completion token split"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="P50 Cost / Request"
|
||||
value={formatMicrodollars(
|
||||
dashboard.cost_p50_microdollars ?? 0,
|
||||
)}
|
||||
subtitle="Median cost per request"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="P75 Cost / Request"
|
||||
value={formatMicrodollars(
|
||||
dashboard.cost_p75_microdollars ?? 0,
|
||||
)}
|
||||
subtitle="75th percentile cost"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="P95 Cost / Request"
|
||||
value={formatMicrodollars(
|
||||
dashboard.cost_p95_microdollars ?? 0,
|
||||
)}
|
||||
subtitle="95th percentile cost"
|
||||
/>
|
||||
<SummaryCard
|
||||
label="P99 Cost / Request"
|
||||
value={formatMicrodollars(
|
||||
dashboard.cost_p99_microdollars ?? 0,
|
||||
)}
|
||||
subtitle="99th percentile cost"
|
||||
/>
|
||||
</div>
|
||||
{(() => {
|
||||
const summaryCards: {
|
||||
label: string;
|
||||
value: string;
|
||||
subtitle?: string;
|
||||
}[] = [
|
||||
{
|
||||
label: "Known Cost",
|
||||
value: formatMicrodollars(
|
||||
dashboard.total_cost_microdollars,
|
||||
),
|
||||
subtitle: "From providers that report USD cost",
|
||||
},
|
||||
{
|
||||
label: "Estimated Total",
|
||||
value: formatMicrodollars(totalEstimatedCost),
|
||||
subtitle: "Including per-run cost estimates",
|
||||
},
|
||||
{
|
||||
label: "Total Requests",
|
||||
value: dashboard.total_requests.toLocaleString(),
|
||||
},
|
||||
{
|
||||
label: "Active Users",
|
||||
value: dashboard.total_users.toLocaleString(),
|
||||
},
|
||||
{
|
||||
label: "Avg Cost / Request",
|
||||
value: formatMicrodollars(
|
||||
dashboard.avg_cost_microdollars_per_request ?? 0,
|
||||
),
|
||||
subtitle: "Known cost divided by cost-bearing requests",
|
||||
},
|
||||
{
|
||||
label: "Avg Input Tokens",
|
||||
value: Math.round(
|
||||
dashboard.avg_input_tokens_per_request ?? 0,
|
||||
).toLocaleString(),
|
||||
subtitle: "Prompt tokens per request (context size)",
|
||||
},
|
||||
{
|
||||
label: "Avg Output Tokens",
|
||||
value: Math.round(
|
||||
dashboard.avg_output_tokens_per_request ?? 0,
|
||||
).toLocaleString(),
|
||||
subtitle:
|
||||
"Completion tokens per request (response length)",
|
||||
},
|
||||
{
|
||||
label: "Total Tokens",
|
||||
value: `${formatTokens(dashboard.total_input_tokens ?? 0)} in / ${formatTokens(dashboard.total_output_tokens ?? 0)} out`,
|
||||
subtitle: "Prompt vs completion token split",
|
||||
},
|
||||
{
|
||||
label: "Typical Cost (P50)",
|
||||
value: formatMicrodollars(
|
||||
dashboard.cost_p50_microdollars ?? 0,
|
||||
),
|
||||
subtitle: "Median cost per request",
|
||||
},
|
||||
{
|
||||
label: "Upper Cost (P75)",
|
||||
value: formatMicrodollars(
|
||||
dashboard.cost_p75_microdollars ?? 0,
|
||||
),
|
||||
subtitle: "75th percentile cost",
|
||||
},
|
||||
{
|
||||
label: "High Cost (P95)",
|
||||
value: formatMicrodollars(
|
||||
dashboard.cost_p95_microdollars ?? 0,
|
||||
),
|
||||
subtitle: "95th percentile cost",
|
||||
},
|
||||
{
|
||||
label: "Peak Cost (P99)",
|
||||
value: formatMicrodollars(
|
||||
dashboard.cost_p99_microdollars ?? 0,
|
||||
),
|
||||
subtitle: "99th percentile cost",
|
||||
},
|
||||
];
|
||||
return (
|
||||
<div className="grid grid-cols-2 gap-4 sm:grid-cols-3 md:grid-cols-4">
|
||||
{summaryCards.map((card) => (
|
||||
<SummaryCard
|
||||
key={card.label}
|
||||
label={card.label}
|
||||
value={card.value}
|
||||
subtitle={card.subtitle}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
})()}
|
||||
|
||||
{dashboard.cost_buckets && dashboard.cost_buckets.length > 0 && (
|
||||
<div className="rounded-lg border p-4">
|
||||
|
||||
Reference in New Issue
Block a user