fix(backend): address all open review comments on platform cost tracking

- Normalize provider to lowercase at write time; drop LOWER() in filter so
  the (provider, createdAt) index is used without function overhead
- Drop COALESCE(trackingType, metadata->>'tracking_type') fallback — new rows
  always have trackingType set at write time
- Derive total_users from len(by_user_rows) instead of a separate
  COUNT(DISTINCT userId) query (saves one aggregation per dashboard load)
- Add 30-second TTLCache for dashboard endpoint (cachetools, maxsize=256)
- Add backpressure/bounds comment to _pending_log_tasks in platform_cost.py
- Convert f-string logger calls in token_tracking.py to lazy %s formatting
- Add 6 block-level tests for ExaCodeContextBlock and ExaContentsBlock cost
  paths: valid/invalid/zero cost_dollars strings and None cost_dollars
- Update existing tests to match provider-lowercasing and 2-query dashboard
This commit is contained in:
Zamil Majdy
2026-04-07 16:23:10 +07:00
parent 50a8df3d67
commit c9461836c6
6 changed files with 277 additions and 35 deletions

View File

@@ -2,6 +2,7 @@ import logging
from datetime import datetime
from autogpt_libs.auth import get_user_id, requires_admin_user
from cachetools import TTLCache
from fastapi import APIRouter, Query, Security
from pydantic import BaseModel
@@ -15,6 +16,13 @@ from backend.util.models import Pagination
logger = logging.getLogger(__name__)
# Cache dashboard results for 30 seconds per unique filter combination.
# The table is append-only so stale reads are acceptable for analytics.
_DASHBOARD_CACHE_TTL = 30
_dashboard_cache: TTLCache[tuple, PlatformCostDashboard] = TTLCache(
maxsize=256, ttl=_DASHBOARD_CACHE_TTL
)
router = APIRouter(
prefix="/platform-costs",
@@ -41,12 +49,18 @@ async def get_cost_dashboard(
user_id: str | None = Query(None),
):
logger.info("Admin %s fetching platform cost dashboard", admin_user_id)
return await get_platform_cost_dashboard(
cache_key = (start, end, provider, user_id)
cached = _dashboard_cache.get(cache_key)
if cached is not None:
return cached
result = await get_platform_cost_dashboard(
start=start,
end=end,
provider=provider,
user_id=user_id,
)
_dashboard_cache[cache_key] = result
return result
@router.get(

View File

@@ -8,6 +8,7 @@ from autogpt_libs.auth.jwt_utils import get_jwt_payload
from backend.data.platform_cost import PlatformCostDashboard
from . import platform_cost_routes
from .platform_cost_routes import router as platform_cost_router
app = fastapi.FastAPI()
@@ -20,6 +21,8 @@ client = fastapi.testclient.TestClient(app)
def setup_app_admin_auth(mock_jwt_admin):
    """Set up admin auth overrides for all tests in this module.

    Installs the mocked JWT payload dependency before each test, yields to run
    the test, then removes all overrides on teardown.
    """
    app.dependency_overrides[get_jwt_payload] = mock_jwt_admin["get_jwt_payload"]
    # Clear TTL cache so each test starts cold.
    # Without this, a dashboard result cached by one test (30s TTL) could be
    # served to the next test and mask differing query behavior.
    platform_cost_routes._dashboard_cache.clear()
    yield
    # Teardown: restore the app's original dependency wiring.
    app.dependency_overrides.clear()

View File

@@ -0,0 +1,210 @@
"""Tests for cost tracking in Exa blocks.
Covers the cost_dollars → provider_cost → merge_stats path for both
ExaContentsBlock and ExaCodeContextBlock.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.blocks.exa._test import TEST_CREDENTIALS, TEST_CREDENTIALS_INPUT
from backend.data.model import NodeExecutionStats
class TestExaCodeContextCostTracking:
    """ExaCodeContextBlock parses cost_dollars (string) and calls merge_stats."""

    async def _run_block(self, block, query, payload):
        """Drive ``block.run`` with a mocked Requests whose POST returns *payload*.

        Returns the list of (output_name, value) pairs the block yielded.
        """
        with patch("backend.blocks.exa.code_context.Requests") as requests_cls:
            response = MagicMock()
            response.json.return_value = payload
            requests_cls.return_value.post = AsyncMock(return_value=response)
            collected = []
            async for name, value in block.run(
                block.Input(query=query, credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                collected.append((name, value))
        return collected

    @pytest.mark.asyncio
    async def test_valid_cost_string_is_parsed_and_merged(self):
        """A numeric cost string like '0.005' is merged as provider_cost."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = captured.append  # type: ignore[assignment]
        payload = {
            "requestId": "req-1",
            "query": "test query",
            "response": "some code",
            "resultsCount": 3,
            "costDollars": "0.005",
            "searchTime": 1.2,
            "outputTokens": 100,
        }
        outputs = await self._run_block(block, "test query", payload)
        assert any(name == "cost_dollars" for name, _ in outputs)
        assert len(captured) == 1
        assert captured[0].provider_cost == pytest.approx(0.005)

    @pytest.mark.asyncio
    async def test_invalid_cost_string_does_not_raise(self):
        """A non-numeric cost_dollars value is swallowed silently."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = captured.append  # type: ignore[assignment]
        payload = {
            "requestId": "req-2",
            "query": "test",
            "response": "code",
            "resultsCount": 0,
            "costDollars": "N/A",
            "searchTime": 0.5,
            "outputTokens": 0,
        }
        await self._run_block(block, "test", payload)
        # No merge_stats call because float() raised ValueError
        assert len(captured) == 0

    @pytest.mark.asyncio
    async def test_zero_cost_string_is_merged(self):
        """'0.0' is a valid cost — should still be tracked."""
        from backend.blocks.exa.code_context import ExaCodeContextBlock

        block = ExaCodeContextBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = captured.append  # type: ignore[assignment]
        payload = {
            "requestId": "req-3",
            "query": "free query",
            "response": "result",
            "resultsCount": 1,
            "costDollars": "0.0",
            "searchTime": 0.1,
            "outputTokens": 10,
        }
        await self._run_block(block, "free query", payload)
        assert len(captured) == 1
        assert captured[0].provider_cost == pytest.approx(0.0)
class TestExaContentsCostTracking:
    """ExaContentsBlock merges cost_dollars.total as provider_cost."""

    @staticmethod
    def _sdk_response(cost_dollars):
        """Build a minimal mocked AsyncExa get_contents response."""
        response = MagicMock()
        response.results = []
        response.context = None
        response.statuses = None
        response.cost_dollars = cost_dollars
        return response

    async def _run_block(self, block, sdk_response):
        """Drain ``block.run`` with AsyncExa.get_contents mocked to *sdk_response*."""
        with patch("backend.blocks.exa.contents.AsyncExa") as exa_cls:
            client = MagicMock()
            client.get_contents = AsyncMock(return_value=sdk_response)
            exa_cls.return_value = client
            async for _ in block.run(
                block.Input(urls=["https://example.com"], credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """When the SDK response includes cost_dollars, its total is merged."""
        from backend.blocks.exa.contents import ExaContentsBlock
        from backend.blocks.exa.helpers import CostDollars

        block = ExaContentsBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = captured.append  # type: ignore[assignment]
        await self._run_block(block, self._sdk_response(CostDollars(total=0.012)))
        assert len(captured) == 1
        assert captured[0].provider_cost == pytest.approx(0.012)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """When cost_dollars is absent, merge_stats is not called."""
        from backend.blocks.exa.contents import ExaContentsBlock

        block = ExaContentsBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = captured.append  # type: ignore[assignment]
        await self._run_block(block, self._sdk_response(None))
        assert len(captured) == 0

    @pytest.mark.asyncio
    async def test_zero_cost_dollars_is_merged(self):
        """A total of 0.0 (free tier) should still be merged."""
        from backend.blocks.exa.contents import ExaContentsBlock
        from backend.blocks.exa.helpers import CostDollars

        block = ExaContentsBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = captured.append  # type: ignore[assignment]
        await self._run_block(block, self._sdk_response(CostDollars(total=0.0)))
        assert len(captured) == 1
        assert captured[0].provider_cost == pytest.approx(0.0)

View File

@@ -95,14 +95,23 @@ async def persist_and_record_usage(
if cache_read_tokens or cache_creation_tokens:
logger.info(
f"{log_prefix} Turn usage: uncached={prompt_tokens}, "
f"cache_read={cache_read_tokens}, cache_create={cache_creation_tokens}, "
f"output={completion_tokens}, total={total_tokens}, cost_usd={cost_usd}"
"%s Turn usage: uncached=%d, cache_read=%d, cache_create=%d,"
" output=%d, total=%d, cost_usd=%s",
log_prefix,
prompt_tokens,
cache_read_tokens,
cache_creation_tokens,
completion_tokens,
total_tokens,
cost_usd,
)
else:
logger.info(
f"{log_prefix} Turn usage: prompt={prompt_tokens}, "
f"completion={completion_tokens}, total={total_tokens}"
"%s Turn usage: prompt=%d, completion=%d, total=%d",
log_prefix,
prompt_tokens,
completion_tokens,
total_tokens,
)
if user_id:
@@ -115,7 +124,7 @@ async def persist_and_record_usage(
cache_creation_tokens=cache_creation_tokens,
)
except Exception as usage_err:
logger.warning(f"{log_prefix} Failed to record token usage: {usage_err}")
logger.warning("%s Failed to record token usage: %s", log_prefix, usage_err)
# Log to PlatformCostLog for admin cost dashboard.
# Include entries where cost_usd is set even if token count is 0

View File

@@ -70,7 +70,9 @@ async def log_platform_cost(entry: PlatformCostEntry) -> None:
entry.node_id,
entry.block_id,
entry.block_name,
entry.provider,
# Normalize to lowercase so the (provider, createdAt) index is always
# used without LOWER() on the read side.
entry.provider.lower(),
entry.credential_id,
entry.cost_microdollars,
entry.input_tokens,
@@ -99,6 +101,12 @@ async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
# Hold strong references to in-flight log tasks to prevent GC.
# Tasks remove themselves on completion via add_done_callback.
#
# NOTE: this set is intentionally unbounded. Under sustained high load or DB
# slowness the set could grow without limit. Adding a bounded asyncio.Semaphore
# would provide back-pressure but is deferred until we observe memory pressure
# in production. The set is small in practice because log inserts are fast
# (sub-millisecond on a healthy DB).
_pending_log_tasks: set["asyncio.Task[None]"] = set()
@@ -199,8 +207,10 @@ def _build_where(
params.append(end)
idx += 1
if provider:
clauses.append(f'LOWER({prefix}"provider") = LOWER(${idx})')
params.append(provider)
# Provider names are normalized to lowercase at write time so a plain
# equality check is sufficient and the (provider, createdAt) index is used.
clauses.append(f'{prefix}"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id:
clauses.append(f'{prefix}"userId" = ${idx}')
@@ -231,13 +241,12 @@ async def get_platform_cost_dashboard(
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_p, params_p = _build_where(start, end, provider, user_id, "p")
by_provider_rows, user_count_rows, by_user_rows = await asyncio.gather(
by_provider_rows, by_user_rows = await asyncio.gather(
query_raw_with_schema(
f"""
SELECT
p."provider",
COALESCE(p."trackingType", p."metadata"->>'tracking_type')
AS tracking_type,
p."trackingType" AS tracking_type,
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
@@ -246,21 +255,12 @@ async def get_platform_cost_dashboard(
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
GROUP BY p."provider",
COALESCE(p."trackingType", p."metadata"->>'tracking_type')
GROUP BY p."provider", p."trackingType"
ORDER BY total_cost DESC
LIMIT {MAX_PROVIDER_ROWS}
""",
*params_p,
),
query_raw_with_schema(
f"""
SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
""",
*params_p,
),
query_raw_with_schema(
f"""
SELECT
@@ -281,7 +281,9 @@ async def get_platform_cost_dashboard(
),
)
total_users = user_count_rows[0]["cnt"] if user_count_rows else 0
# Derive total_users from by_user rows rather than a separate
# COUNT(DISTINCT userId) query — avoids a full-table scan.
total_users = len(by_user_rows)
total_cost = sum(r["total_cost"] for r in by_provider_rows)
total_requests = sum(r["request_count"] for r in by_provider_rows)
@@ -352,8 +354,7 @@ async def get_platform_cost_logs(
p."nodeExecId" AS node_exec_id,
p."blockName" AS block_name,
p."provider",
COALESCE(p."trackingType", p."metadata"->>'tracking_type')
AS tracking_type,
p."trackingType" AS tracking_type,
p."costMicrodollars" AS cost_microdollars,
p."inputTokens" AS input_tokens,
p."outputTokens" AS output_tokens,

View File

@@ -49,8 +49,11 @@ class TestBuildWhere:
assert params == [dt]
def test_provider_only(self):
sql, params = _build_where(None, None, "openai", None)
assert 'LOWER("provider") = LOWER($1)' in sql
# Provider names are normalized to lowercase at write time, so the
# filter uses a plain equality check. The input is also lowercased so
# "OpenAI" and "openai" both match stored rows.
sql, params = _build_where(None, None, "OpenAI", None)
assert '"provider" = $1' in sql
assert params == ["openai"]
def test_user_id_only(self):
@@ -61,12 +64,13 @@ class TestBuildWhere:
def test_all_filters(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 6, 1, tzinfo=timezone.utc)
sql, params = _build_where(start, end, "anthropic", "u1")
sql, params = _build_where(start, end, "Anthropic", "u1")
assert "$1" in sql
assert "$2" in sql
assert "$3" in sql
assert "$4" in sql
assert len(params) == 4
# Provider is lowercased at filter time to match stored lowercase values.
assert params == [start, end, "anthropic", "u1"]
def test_table_alias(self):
@@ -157,7 +161,6 @@ class TestGetPlatformCostDashboard:
"request_count": 3,
}
]
user_count_rows = [{"cnt": 2}]
user_rows = [
{
"user_id": "u1",
@@ -168,12 +171,14 @@ class TestGetPlatformCostDashboard:
"request_count": 3,
}
]
mock_query = AsyncMock(side_effect=[provider_rows, user_count_rows, user_rows])
# Dashboard now runs 2 queries (by_provider + by_user) — total_users is
# derived from len(by_user_rows) instead of a separate COUNT query.
mock_query = AsyncMock(side_effect=[provider_rows, user_rows])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 5000
assert dashboard.total_requests == 3
assert dashboard.total_users == 2
assert dashboard.total_users == 1
assert len(dashboard.by_provider) == 1
assert dashboard.by_provider[0].provider == "openai"
assert dashboard.by_provider[0].tracking_type == "tokens"
@@ -183,7 +188,7 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_query = AsyncMock(side_effect=[[], [{"cnt": 0}], []])
mock_query = AsyncMock(side_effect=[[], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 0
@@ -195,12 +200,12 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_passes_filters_to_queries(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_query = AsyncMock(side_effect=[[], [{"cnt": 0}], []])
mock_query = AsyncMock(side_effect=[[], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
assert mock_query.await_count == 3
assert mock_query.await_count == 2
first_call_sql = mock_query.call_args_list[0][0][0]
assert "createdAt" in first_call_sql