Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-04-30 03:00:41 -04:00)
fix(platform-cost): fix per-user avg cost denominator, NULL bucket, tracking_type filter gap
- Add `cost_bearing_request_count` to `UserCostSummary` via a new group-by-(userId, trackingType) query; `UserTable` now divides by this count instead of the mixed `request_count`, eliminating denominator dilution for users with both tokens and cost_usd rows.
- Guard the histogram CASE against NULL costMicrodollars: NULL < N evaluates to unknown, so NULL rows fall through every WHEN into the ELSE '$10+' bucket. Add `AND "costMicrodollars" IS NOT NULL` to the histogram WHERE so NULL rows are excluded instead of bucketed.
- Respect the `tracking_type` dashboard filter in the raw SQL percentile and bucket queries; previously the filter was hardcoded to 'cost_usd' even when the caller passed tracking_type='tokens', making those queries inconsistent with the ORM queries.
- Add p75 and p99 assertions to test_returns_dashboard_with_data.
- Update openapi.json and the generated TS model for the new field.
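The denominator change in the first bullet is easiest to see with numbers. A minimal sketch of the before/after average for a single user, using hypothetical rows (field names mirror the PlatformCostLog columns; the values are illustrative only):

# One user: three token-tracking rows carrying no cost, plus two cost_usd rows
# totalling 500_000 microdollars ($0.50).
rows = [
    {"trackingType": "tokens", "costMicrodollars": None},
    {"trackingType": "tokens", "costMicrodollars": None},
    {"trackingType": "tokens", "costMicrodollars": None},
    {"trackingType": "cost_usd", "costMicrodollars": 200_000},
    {"trackingType": "cost_usd", "costMicrodollars": 300_000},
]

total_cost = sum(r["costMicrodollars"] or 0 for r in rows)             # 500_000
request_count = len(rows)                                              # 5 (mixed)
cost_bearing_request_count = sum(
    1 for r in rows if r["trackingType"] == "cost_usd"
)                                                                      # 2

avg_diluted = total_cost / request_count               # 100_000 ($0.10): old UserTable math
avg_correct = total_cost / cost_bearing_request_count  # 250_000 ($0.25): new denominator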
@@ -3,12 +3,11 @@ import logging
 from datetime import datetime, timedelta, timezone
 from typing import Any
 
-from typing_extensions import TypedDict
-
 from prisma.models import PlatformCostLog as PrismaLog
 from prisma.models import User as PrismaUser
 from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
 from pydantic import BaseModel
+from typing_extensions import TypedDict
 
 from backend.data.db import query_raw_with_schema
 from backend.util.cache import cached
@@ -143,6 +142,7 @@ class UserCostSummary(BaseModel):
     total_input_tokens: int
     total_output_tokens: int
     request_count: int
+    cost_bearing_request_count: int = 0
 
 
 class CostLogRow(BaseModel):
@@ -286,7 +286,7 @@ async def get_platform_cost_dashboard(
     # queries so they honour all active dashboard filters, not just start date.
     raw_params: list = [start]
     raw_where_clauses = [
-        '"trackingType" = \'cost_usd\'',
+        "\"trackingType\" = 'cost_usd'",
         '"createdAt" >= $1',
     ]
     param_idx = 2  # $1 is already start
@@ -316,12 +316,20 @@ async def get_platform_cost_dashboard(
         raw_params.append(block_name)
         param_idx += 1
 
+    # If the caller supplied a specific tracking_type filter, replace the
+    # hardcoded cost_usd clause so the percentile/bucket queries respect it.
+    if tracking_type is not None:
+        raw_where_clauses[0] = f'"trackingType" = ${param_idx}'
+        raw_params.append(tracking_type)
+        param_idx += 1
+
     raw_where = " AND ".join(raw_where_clauses)
 
-    # Run all six aggregation queries in parallel.
+    # Run all seven aggregation queries in parallel.
     (
         by_provider_groups,
         by_user_groups,
+        by_user_tracking_groups,
         total_user_groups,
         total_agg_groups,
         percentile_rows,
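For readers skimming the hunk above: the clause list and positional-parameter bookkeeping can be exercised in isolation. A minimal sketch using the same names as the diff (the standalone helper is hypothetical, not part of the module):

from datetime import datetime, timezone


def build_raw_where(start: datetime, tracking_type: str | None = None) -> tuple[str, list]:
    # Clause [0] starts as the hardcoded cost_usd filter; when the caller
    # narrows tracking_type it is swapped for a positional parameter.
    raw_params: list = [start]
    raw_where_clauses = ["\"trackingType\" = 'cost_usd'", '"createdAt" >= $1']
    param_idx = 2  # $1 is already taken by start

    if tracking_type is not None:
        raw_where_clauses[0] = f'"trackingType" = ${param_idx}'
        raw_params.append(tracking_type)
        param_idx += 1

    return " AND ".join(raw_where_clauses), raw_params


# With tracking_type="tokens" the percentile/bucket SQL now filters on $2
# instead of silently keeping the cost_usd literal:
where, params = build_raw_where(datetime(2026, 1, 1, tzinfo=timezone.utc), "tokens")
# where  == '"trackingType" = $2 AND "createdAt" >= $1'
# params == [datetime(2026, 1, 1, ...), "tokens"]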
@@ -342,6 +350,13 @@ async def get_platform_cost_dashboard(
             sum=sum_fields,
             count=True,
         ),
+        # Per-user cost-bearing request count: group by (userId, trackingType)
+        # so we can compute the correct denominator for per-user avg cost.
+        PrismaLog.prisma().group_by(
+            by=["userId", "trackingType"],
+            where=where,
+            count=True,
+        ),
         # Distinct user count: group by userId, count groups.
         PrismaLog.prisma().group_by(
             by=["userId"],
@@ -376,6 +391,8 @@ async def get_platform_cost_dashboard(
             *raw_params,
         ),
         # Histogram buckets for cost distribution (respects all filters).
+        # NULL costMicrodollars is excluded explicitly to prevent such rows
+        # from falling through all WHEN clauses into the ELSE '$10+' bucket.
         query_raw_with_schema(
             "SELECT"
             " CASE"
@@ -393,7 +410,7 @@ async def get_platform_cost_dashboard(
             " END as bucket,"
             " COUNT(*) as count"
             ' FROM {schema_prefix}"PlatformCostLog"'
-            f" WHERE {raw_where}"
+            f' WHERE {raw_where} AND "costMicrodollars" IS NOT NULL'
             " GROUP BY bucket"
             ' ORDER BY MIN("costMicrodollars")',
             *raw_params,
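Why the explicit IS NOT NULL guard matters: in SQL's three-valued logic, NULL < threshold evaluates to UNKNOWN, so a row with NULL costMicrodollars satisfies none of the WHEN branches and is counted by the ELSE '$10+' arm. A rough Python analogue of the bucketing (the thresholds here are assumed for illustration; the real boundaries live in the CASE expression above):

def bucket(cost_microdollars: int | None) -> str | None:
    # Filtering None first mirrors the added AND "costMicrodollars" IS NOT NULL.
    if cost_microdollars is None:
        return None  # row is excluded from the histogram entirely
    # Assumed example thresholds, not the actual WHEN boundaries from the query.
    if cost_microdollars < 10_000:       # under $0.01
        return "<$0.01"
    if cost_microdollars < 1_000_000:    # under $1
        return "$0.01-$1"
    if cost_microdollars < 10_000_000:   # under $10
        return "$1-$10"
    return "$10+"  # the ELSE arm that NULL rows previously fell into


assert bucket(None) is None        # previously such rows inflated the "$10+" count
assert bucket(5_000) == "<$0.01"
assert bucket(25_000_000) == "$10+"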
@@ -448,6 +465,16 @@ async def get_platform_cost_dashboard(
         _ca(r) for r in total_agg_groups if r.get("trackingType") == "tokens"
     )
 
+    # Per-user cost-bearing request count: used for per-user avg cost so the
+    # denominator matches the numerator (cost_usd rows only, per user).
+    user_cost_bearing_counts: dict[str, int] = {}
+    for r in by_user_tracking_groups:
+        if r.get("trackingType") == "cost_usd" and r.get("userId"):
+            uid = r["userId"]
+            user_cost_bearing_counts[uid] = user_cost_bearing_counts.get(uid, 0) + _ca(
+                r
+            )
+
     return PlatformCostDashboard(
         by_provider=[
             ProviderCostSummary(
@@ -473,6 +500,9 @@ async def get_platform_cost_dashboard(
                 total_input_tokens=_si(r, "inputTokens"),
                 total_output_tokens=_si(r, "outputTokens"),
                 request_count=_ca(r),
+                cost_bearing_request_count=user_cost_bearing_counts.get(
+                    r.get("userId") or "", 0
+                ),
             )
             for r in by_user_groups
         ],
@@ -286,6 +286,7 @@ class TestGetPlatformCostDashboard:
            side_effect=[
                [provider_row],  # by_provider
                [user_row],  # by_user
+               [],  # by_user_tracking_groups (no cost_usd rows for this user)
                [{"userId": "u1"}],  # distinct users
                [provider_row],  # total agg
            ]
@@ -322,7 +323,9 @@ class TestGetPlatformCostDashboard:
        assert len(dashboard.by_user) == 1
        assert dashboard.by_user[0].email == "a***@b.com"
        assert dashboard.cost_p50_microdollars == 1000
+       assert dashboard.cost_p75_microdollars == 2000
        assert dashboard.cost_p95_microdollars == 4000
+       assert dashboard.cost_p99_microdollars == 5000
        assert len(dashboard.cost_buckets) == 1
        # Token averages must use token_bearing_requests (3) not cost_bearing (0)
        assert dashboard.avg_input_tokens_per_request == pytest.approx(1000 / 3)
@@ -351,6 +354,7 @@ class TestGetPlatformCostDashboard:
            side_effect=[
                [provider_row],  # by_provider
                [user_row],  # by_user
+               [],  # by_user_tracking_groups
                [{"userId": "u2"}],  # distinct users
                [provider_row],  # total agg
            ]
@@ -385,7 +389,7 @@ class TestGetPlatformCostDashboard:
    @pytest.mark.asyncio
    async def test_returns_empty_dashboard(self):
        mock_actions = MagicMock()
-       mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
+       mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
        mock_actions.find_many = AsyncMock(return_value=[])
 
        with (
@@ -418,7 +422,7 @@ class TestGetPlatformCostDashboard:
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
 
        mock_actions = MagicMock()
-       mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
+       mock_actions.group_by = AsyncMock(side_effect=[[], [], [], [], []])
        mock_actions.find_many = AsyncMock(return_value=[])
 
        raw_mock = AsyncMock(side_effect=[[], []])
@@ -440,8 +444,8 @@ class TestGetPlatformCostDashboard:
            start=start, provider="openai", user_id="u1"
        )
 
-       # group_by called 4 times (by_provider, by_user, distinct users, totals)
-       assert mock_actions.group_by.await_count == 4
+       # group_by called 5 times (by_provider, by_user, by_user_tracking, distinct users, totals)
+       assert mock_actions.group_by.await_count == 5
        # The where dict passed to the first call should include createdAt
        first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
        assert "createdAt" in first_call_kwargs.get("where", {})
@@ -58,9 +58,11 @@ function UserTable({ data }: Props) {
            {formatTokens(row.total_output_tokens)}
          </td>
          <td className="px-4 py-3 text-right">
-           {row.request_count > 0 && row.total_cost_microdollars > 0
+           {(row.cost_bearing_request_count ?? 0) > 0 &&
+           row.total_cost_microdollars > 0
              ? formatMicrodollars(
-                 row.total_cost_microdollars / row.request_count,
+                 row.total_cost_microdollars /
+                   row.cost_bearing_request_count!,
                )
              : "-"}
          </td>
19  autogpt_platform/frontend/src/app/api/__generated__/models/userCostSummary.ts  generated  Normal file
@@ -0,0 +1,19 @@
+/**
+ * Generated by orval v7.13.0 🍺
+ * Do not edit manually.
+ * AutoGPT Agent Server
+ * This server is used to execute agents that are created by the AutoGPT system.
+ * OpenAPI spec version: 0.1
+ */
+import type { UserCostSummaryUserId } from "./userCostSummaryUserId";
+import type { UserCostSummaryEmail } from "./userCostSummaryEmail";
+
+export interface UserCostSummary {
+  user_id?: UserCostSummaryUserId;
+  email?: UserCostSummaryEmail;
+  total_cost_microdollars: number;
+  total_input_tokens: number;
+  total_output_tokens: number;
+  request_count: number;
+  cost_bearing_request_count?: number;
+}
@@ -15585,7 +15585,8 @@
          "type": "integer",
          "title": "Total Output Tokens"
        },
-       "request_count": { "type": "integer", "title": "Request Count" }
+       "request_count": { "type": "integer", "title": "Request Count" },
+       "cost_bearing_request_count": { "type": "integer", "title": "Cost Bearing Request Count", "default": 0 }
      },
      "type": "object",
      "required": [