fix(platform-cost): fix per-user avg cost denominator, NULL bucket, tracking_type filter gap

- Add `cost_bearing_request_count` to `UserCostSummary` via a new
  group-by-(userId,trackingType) query; `UserTable` now divides by
  this count instead of the mixed `request_count`, eliminating
  denominator dilution for users with both tokens and cost_usd rows
- Guard histogram CASE against NULL costMicrodollars: in SQL's three-valued
  logic, `NULL < N` evaluates to unknown, so every WHEN clause is skipped and
  NULL rows fall through to the ELSE '$10+' bucket. Add
  `AND "costMicrodollars" IS NOT NULL` to the histogram WHERE so NULL rows
  are excluded instead of being mis-bucketed
- Respect the `tracking_type` dashboard filter in raw SQL percentile
  and bucket queries; previously the filter was hardcoded to 'cost_usd'
  even when the caller passed tracking_type='tokens', making those
  queries return inconsistent data relative to the ORM queries
- Add p75 and p99 assertions to test_returns_dashboard_with_data
- Update openapi.json and generated TS model for new field
This commit is contained in:
majdyz
2026-04-13 02:10:40 +00:00
parent ac973396a2
commit 187b4596e0
5 changed files with 68 additions and 12 deletions

View File

@@ -3,12 +3,11 @@ import logging
from datetime import datetime, timedelta, timezone
from typing import Any
from typing_extensions import TypedDict
from prisma.models import PlatformCostLog as PrismaLog
from prisma.models import User as PrismaUser
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
from pydantic import BaseModel
from typing_extensions import TypedDict
from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
@@ -143,6 +142,7 @@ class UserCostSummary(BaseModel):
total_input_tokens: int
total_output_tokens: int
request_count: int
cost_bearing_request_count: int = 0
class CostLogRow(BaseModel):
@@ -286,7 +286,7 @@ async def get_platform_cost_dashboard(
# queries so they honour all active dashboard filters, not just start date.
raw_params: list = [start]
raw_where_clauses = [
'"trackingType" = \'cost_usd\'',
"\"trackingType\" = 'cost_usd'",
'"createdAt" >= $1',
]
param_idx = 2 # $1 is already start
@@ -316,12 +316,20 @@ async def get_platform_cost_dashboard(
raw_params.append(block_name)
param_idx += 1
# If the caller supplied a specific tracking_type filter, replace the
# hardcoded cost_usd clause so the percentile/bucket queries respect it.
if tracking_type is not None:
raw_where_clauses[0] = f'"trackingType" = ${param_idx}'
raw_params.append(tracking_type)
param_idx += 1
raw_where = " AND ".join(raw_where_clauses)
# Run all six aggregation queries in parallel.
# Run all seven aggregation queries in parallel.
(
by_provider_groups,
by_user_groups,
by_user_tracking_groups,
total_user_groups,
total_agg_groups,
percentile_rows,
@@ -342,6 +350,13 @@ async def get_platform_cost_dashboard(
sum=sum_fields,
count=True,
),
# Per-user cost-bearing request count: group by (userId, trackingType)
# so we can compute the correct denominator for per-user avg cost.
PrismaLog.prisma().group_by(
by=["userId", "trackingType"],
where=where,
count=True,
),
# Distinct user count: group by userId, count groups.
PrismaLog.prisma().group_by(
by=["userId"],
@@ -376,6 +391,8 @@ async def get_platform_cost_dashboard(
*raw_params,
),
# Histogram buckets for cost distribution (respects all filters).
# NULL costMicrodollars is excluded explicitly to prevent such rows
# from falling through all WHEN clauses into the ELSE '$10+' bucket.
query_raw_with_schema(
"SELECT"
" CASE"
@@ -393,7 +410,7 @@ async def get_platform_cost_dashboard(
" END as bucket,"
" COUNT(*) as count"
' FROM {schema_prefix}"PlatformCostLog"'
f" WHERE {raw_where}"
f' WHERE {raw_where} AND "costMicrodollars" IS NOT NULL'
" GROUP BY bucket"
' ORDER BY MIN("costMicrodollars")',
*raw_params,
@@ -448,6 +465,16 @@ async def get_platform_cost_dashboard(
_ca(r) for r in total_agg_groups if r.get("trackingType") == "tokens"
)
# Per-user cost-bearing request count: used for per-user avg cost so the
# denominator matches the numerator (cost_usd rows only, per user).
user_cost_bearing_counts: dict[str, int] = {}
for r in by_user_tracking_groups:
if r.get("trackingType") == "cost_usd" and r.get("userId"):
uid = r["userId"]
user_cost_bearing_counts[uid] = user_cost_bearing_counts.get(uid, 0) + _ca(
r
)
return PlatformCostDashboard(
by_provider=[
ProviderCostSummary(
@@ -473,6 +500,9 @@ async def get_platform_cost_dashboard(
total_input_tokens=_si(r, "inputTokens"),
total_output_tokens=_si(r, "outputTokens"),
request_count=_ca(r),
cost_bearing_request_count=user_cost_bearing_counts.get(
r.get("userId") or "", 0
),
)
for r in by_user_groups
],

View File

@@ -286,6 +286,7 @@ class TestGetPlatformCostDashboard:
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[], # by_user_tracking_groups (no cost_usd rows for this user)
[{"userId": "u1"}], # distinct users
[provider_row], # total agg
]
@@ -322,7 +323,9 @@ class TestGetPlatformCostDashboard:
assert len(dashboard.by_user) == 1
assert dashboard.by_user[0].email == "a***@b.com"
assert dashboard.cost_p50_microdollars == 1000
assert dashboard.cost_p75_microdollars == 2000
assert dashboard.cost_p95_microdollars == 4000
assert dashboard.cost_p99_microdollars == 5000
assert len(dashboard.cost_buckets) == 1
# Token averages must use token_bearing_requests (3) not cost_bearing (0)
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 3)
@@ -351,6 +354,7 @@ class TestGetPlatformCostDashboard:
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[], # by_user_tracking_groups
[{"userId": "u2"}], # distinct users
[provider_row], # total agg
]
@@ -385,7 +389,7 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
@@ -418,7 +422,7 @@ class TestGetPlatformCostDashboard:
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
raw_mock = AsyncMock(side_effect=[[], []])
@@ -440,8 +444,8 @@ class TestGetPlatformCostDashboard:
start=start, provider="openai", user_id="u1"
)
# group_by called 4 times (by_provider, by_user, distinct users, totals)
assert mock_actions.group_by.await_count == 4
# group_by called 5 times (by_provider, by_user, by_user_tracking, distinct users, totals)
assert mock_actions.group_by.await_count == 5
# The where dict passed to the first call should include createdAt
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
assert "createdAt" in first_call_kwargs.get("where", {})

View File

@@ -58,9 +58,11 @@ function UserTable({ data }: Props) {
{formatTokens(row.total_output_tokens)}
</td>
<td className="px-4 py-3 text-right">
{row.request_count > 0 && row.total_cost_microdollars > 0
{(row.cost_bearing_request_count ?? 0) > 0 &&
row.total_cost_microdollars > 0
? formatMicrodollars(
row.total_cost_microdollars / row.request_count,
row.total_cost_microdollars /
row.cost_bearing_request_count!,
)
: "-"}
</td>

View File

@@ -0,0 +1,19 @@
/**
* Generated by orval v7.13.0 🍺
* Do not edit manually.
* AutoGPT Agent Server
* This server is used to execute agents that are created by the AutoGPT system.
* OpenAPI spec version: 0.1
*/
import type { UserCostSummaryUserId } from "./userCostSummaryUserId";
import type { UserCostSummaryEmail } from "./userCostSummaryEmail";
export interface UserCostSummary {
// User identifier; generated union type (nullable/optional per the API schema).
user_id?: UserCostSummaryUserId;
// User email; generated union type (nullable/optional per the API schema).
email?: UserCostSummaryEmail;
// Total spend for this user, in microdollars.
total_cost_microdollars: number;
total_input_tokens: number;
total_output_tokens: number;
// Count of ALL tracked requests for the user (tokens and cost_usd rows mixed).
request_count: number;
// Count of cost_usd-bearing requests only — the correct denominator for
// per-user average cost. Optional for backward compatibility; the server
// defaults it to 0 (see the OpenAPI schema), so callers use `?? 0`.
cost_bearing_request_count?: number;
}

View File

@@ -15585,7 +15585,8 @@
"type": "integer",
"title": "Total Output Tokens"
},
"request_count": { "type": "integer", "title": "Request Count" }
"request_count": { "type": "integer", "title": "Request Count" },
"cost_bearing_request_count": { "type": "integer", "title": "Cost Bearing Request Count", "default": 0 }
},
"type": "object",
"required": [