mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(platform-cost): avg stats use unfiltered agg to stay nonzero when tracking_type filtered
When a caller filters the dashboard by tracking_type='tokens', total_agg_groups contains only tokens rows, so cost_bearing_requests was 0 and avg_cost_microdollars_per_request silently returned 0.0. Symmetrically, filtering by tracking_type='cost_usd' produced zero token averages. Add a parallel total_agg_no_tracking_type_groups query (using where_no_tracking_type, mirroring the fix already applied to by_user_tracking_groups) and derive avg_cost_total, avg_input_total, avg_output_total, cost_bearing_requests, and token_bearing_requests from that aggregate, which drops only the tracking_type filter and still honors every other active filter. The displayed grand totals (total_cost, total_requests, total_input_tokens) remain scoped to the active filter, including tracking_type. Also add test_global_avg_cost_nonzero_when_filtering_by_tokens to cover this case.
This commit is contained in:
@@ -333,13 +333,14 @@ async def get_platform_cost_dashboard(
|
||||
|
||||
raw_where = " AND ".join(raw_where_clauses)
|
||||
|
||||
# Run all seven aggregation queries in parallel.
|
||||
# Run all eight aggregation queries in parallel.
|
||||
(
|
||||
by_provider_groups,
|
||||
by_user_groups,
|
||||
by_user_tracking_groups,
|
||||
total_user_groups,
|
||||
total_agg_groups,
|
||||
total_agg_no_tracking_type_groups,
|
||||
percentile_rows,
|
||||
bucket_rows,
|
||||
) = await asyncio.gather(
|
||||
@@ -373,8 +374,8 @@ async def get_platform_cost_dashboard(
|
||||
where=where,
|
||||
count=True,
|
||||
),
|
||||
# Total aggregate: group by (provider, trackingType) so we can
|
||||
# distinguish cost-bearing rows for per-request averages.
|
||||
# Total aggregate (filtered): group by (provider, trackingType) so we can
|
||||
# compute grand totals for cost/tokens within the active filter window.
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["provider", "trackingType"],
|
||||
where=where,
|
||||
@@ -385,6 +386,20 @@ async def get_platform_cost_dashboard(
|
||||
},
|
||||
count=True,
|
||||
),
|
||||
# Total aggregate (no tracking_type filter): used to compute
|
||||
# cost_bearing_requests and token_bearing_requests denominators so
|
||||
# global avg stats remain meaningful when the caller filters the main
|
||||
# view by a specific tracking_type (e.g. 'tokens').
|
||||
PrismaLog.prisma().group_by(
|
||||
by=["provider", "trackingType"],
|
||||
where=where_no_tracking_type,
|
||||
sum={
|
||||
"costMicrodollars": True,
|
||||
"inputTokens": True,
|
||||
"outputTokens": True,
|
||||
},
|
||||
count=True,
|
||||
),
|
||||
# Percentile distribution of cost per request (respects all filters).
|
||||
query_raw_with_schema(
|
||||
"SELECT"
|
||||
@@ -465,14 +480,37 @@ async def get_platform_cost_dashboard(
|
||||
CostBucket(bucket=r["bucket"], count=int(r["count"])) for r in bucket_rows
|
||||
]
|
||||
|
||||
# Cost-bearing request count: only rows where trackingType == "cost_usd".
|
||||
# Avg-stat numerators and denominators are derived from the unfiltered
|
||||
# aggregate so they remain meaningful when the caller filters by a specific
|
||||
# tracking_type. Example: filtering by 'tokens' excludes cost_usd rows from
|
||||
# total_agg_groups, so avg_cost would always be 0 if we used that; using
|
||||
# total_agg_no_tracking_type_groups gives the correct cost_usd total/count.
|
||||
avg_cost_total = sum(
|
||||
_si(r, "costMicrodollars")
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "cost_usd"
|
||||
)
|
||||
cost_bearing_requests = sum(
|
||||
_ca(r) for r in total_agg_groups if r.get("trackingType") == "cost_usd"
|
||||
_ca(r)
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "cost_usd"
|
||||
)
|
||||
avg_input_total = sum(
|
||||
_si(r, "inputTokens")
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "tokens"
|
||||
)
|
||||
avg_output_total = sum(
|
||||
_si(r, "outputTokens")
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "tokens"
|
||||
)
|
||||
# Token-bearing request count: only rows where trackingType == "tokens".
|
||||
# Token averages must use this denominator; cost_usd rows do not carry tokens.
|
||||
token_bearing_requests = sum(
|
||||
_ca(r) for r in total_agg_groups if r.get("trackingType") == "tokens"
|
||||
_ca(r)
|
||||
for r in total_agg_no_tracking_type_groups
|
||||
if r.get("trackingType") == "tokens"
|
||||
)
|
||||
|
||||
# Per-user cost-bearing request count: used for per-user avg cost so the
|
||||
@@ -522,17 +560,17 @@ async def get_platform_cost_dashboard(
|
||||
total_input_tokens=total_input_tokens,
|
||||
total_output_tokens=total_output_tokens,
|
||||
avg_input_tokens_per_request=(
|
||||
total_input_tokens / token_bearing_requests
|
||||
avg_input_total / token_bearing_requests
|
||||
if token_bearing_requests > 0
|
||||
else 0.0
|
||||
),
|
||||
avg_output_tokens_per_request=(
|
||||
total_output_tokens / token_bearing_requests
|
||||
avg_output_total / token_bearing_requests
|
||||
if token_bearing_requests > 0
|
||||
else 0.0
|
||||
),
|
||||
avg_cost_microdollars_per_request=(
|
||||
total_cost / cost_bearing_requests if cost_bearing_requests > 0 else 0.0
|
||||
avg_cost_total / cost_bearing_requests if cost_bearing_requests > 0 else 0.0
|
||||
),
|
||||
cost_p50_microdollars=cost_p50,
|
||||
cost_p75_microdollars=cost_p75,
|
||||
|
||||
@@ -288,7 +288,8 @@ class TestGetPlatformCostDashboard:
|
||||
[user_row], # by_user
|
||||
[], # by_user_tracking_groups (no cost_usd rows for this user)
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[provider_row], # total agg
|
||||
[provider_row], # total agg (filtered)
|
||||
[provider_row], # total agg (no tracking_type filter)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[mock_user])
|
||||
@@ -364,7 +365,8 @@ class TestGetPlatformCostDashboard:
|
||||
user_tracking_tokens_row,
|
||||
], # by_user_tracking
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[total_row], # total agg
|
||||
[total_row], # total agg (filtered)
|
||||
[total_row], # total agg (no tracking_type filter)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
@@ -392,6 +394,57 @@ class TestGetPlatformCostDashboard:
|
||||
assert len(dashboard.by_user) == 1
|
||||
assert dashboard.by_user[0].cost_bearing_request_count == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_global_avg_cost_nonzero_when_filtering_by_tokens(self):
|
||||
"""When filtering by tracking_type='tokens', avg_cost_microdollars_per_request
|
||||
must still reflect cost_usd rows from total_agg_no_tracking_type_groups,
|
||||
not the filtered total_agg_groups which only has tokens rows."""
|
||||
# filtered total_agg only has tokens rows (zero cost)
|
||||
tokens_row = _make_group_by_row(
|
||||
provider="openai", tracking_type="tokens", cost=0, count=5
|
||||
)
|
||||
# unfiltered total_agg has both rows (cost_usd carries the actual cost)
|
||||
cost_usd_row = _make_group_by_row(
|
||||
provider="openai", tracking_type="cost_usd", cost=10_000, count=4
|
||||
)
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(
|
||||
side_effect=[
|
||||
[tokens_row], # by_provider
|
||||
[{"_sum": {}, "_count": {"_all": 5}, "userId": "u1"}], # by_user
|
||||
[], # by_user_tracking_groups
|
||||
[{"userId": "u1"}], # distinct users
|
||||
[tokens_row], # total agg (filtered — tokens only)
|
||||
[tokens_row, cost_usd_row], # total agg (no tracking_type filter)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaLog.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.PrismaUser.prisma",
|
||||
return_value=mock_actions,
|
||||
),
|
||||
patch(
|
||||
"backend.data.platform_cost.query_raw_with_schema",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=[[], []],
|
||||
),
|
||||
):
|
||||
dashboard = await get_platform_cost_dashboard(tracking_type="tokens")
|
||||
|
||||
# avg_cost_microdollars_per_request must be non-zero: cost_usd row
|
||||
# (10_000 microdollars, 4 requests) is present in the unfiltered agg.
|
||||
assert dashboard.avg_cost_microdollars_per_request == pytest.approx(10_000 / 4)
|
||||
# avg token stats use token_bearing_requests from unfiltered agg (5)
|
||||
assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 5)
|
||||
assert dashboard.avg_output_tokens_per_request == pytest.approx(500 / 5)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_tokens_aggregated_not_hardcoded(self):
|
||||
"""cache_read_tokens and cache_creation_tokens must be read from the
|
||||
@@ -415,7 +468,8 @@ class TestGetPlatformCostDashboard:
|
||||
[user_row], # by_user
|
||||
[], # by_user_tracking_groups
|
||||
[{"userId": "u2"}], # distinct users
|
||||
[provider_row], # total agg
|
||||
[provider_row], # total agg (filtered)
|
||||
[provider_row], # total agg (no tracking_type filter)
|
||||
]
|
||||
)
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
@@ -448,7 +502,7 @@ class TestGetPlatformCostDashboard:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_dashboard(self):
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], [], []])
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
@@ -481,7 +535,7 @@ class TestGetPlatformCostDashboard:
|
||||
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], [], []])
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
raw_mock = AsyncMock(side_effect=[[], []])
|
||||
@@ -503,8 +557,9 @@ class TestGetPlatformCostDashboard:
|
||||
start=start, provider="openai", user_id="u1"
|
||||
)
|
||||
|
||||
# group_by called 5 times (by_provider, by_user, by_user_tracking, distinct users, totals)
|
||||
assert mock_actions.group_by.await_count == 5
|
||||
# group_by called 6 times (by_provider, by_user, by_user_tracking, distinct users,
|
||||
# total agg filtered, total agg no-tracking-type)
|
||||
assert mock_actions.group_by.await_count == 6
|
||||
# The where dict passed to the first call should include createdAt
|
||||
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
|
||||
assert "createdAt" in first_call_kwargs.get("where", {})
|
||||
@@ -520,7 +575,7 @@ class TestGetPlatformCostDashboard:
|
||||
"""by_user_tracking_groups must NOT apply the tracking_type filter so that
|
||||
cost_usd rows are always included even when the caller filters by 'tokens'."""
|
||||
mock_actions = MagicMock()
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
|
||||
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], [], []])
|
||||
mock_actions.find_many = AsyncMock(return_value=[])
|
||||
|
||||
with (
|
||||
@@ -541,7 +596,7 @@ class TestGetPlatformCostDashboard:
|
||||
await get_platform_cost_dashboard(tracking_type="tokens")
|
||||
|
||||
# Call index 2 is by_user_tracking_groups (0=by_provider, 1=by_user,
|
||||
# 2=by_user_tracking, 3=distinct_users, 4=total_agg).
|
||||
# 2=by_user_tracking, 3=distinct_users, 4=total_agg, 5=total_agg_no_tt).
|
||||
tracking_call_where = mock_actions.group_by.call_args_list[2][1]["where"]
|
||||
# The main filter applies trackingType; by_user_tracking must NOT.
|
||||
assert "trackingType" not in tracking_call_where
|
||||
|
||||
Reference in New Issue
Block a user