From b319c26cab1b4b293ae73acdd1bd6fb0f5299eec Mon Sep 17 00:00:00 2001 From: Zamil Majdy Date: Fri, 10 Apr 2026 23:14:43 +0700 Subject: [PATCH] feat(platform/admin): per-model cost breakdown, cache token tracking, OrchestratorBlock cost fix (#12726) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why The platform cost tracking system had several gaps that made the admin dashboard less accurate and harder to reason about: **Q: Do we have per-model granularity on the provider page?** The `model` column was stored in `PlatformCostLog` but the SQL aggregation grouped only by `(provider, tracking_type)`, so all models for a given provider collapsed into one row. Now grouped by `(provider, tracking_type, model)` — each model gets its own row. **Q: Why does Anthropic show `per_run` for OrchestratorBlock?** Bug: `OrchestratorBlock._call_llm()` was building `NodeExecutionStats` with only `input_token_count` and `output_token_count` — it dropped `resp.provider_cost` entirely. For OpenRouter calls this silently discarded the `cost_usd`. For the SDK (autopilot) path, `ResultMessage.total_cost_usd` was never read. When `provider_cost` is None and token counts are 0 (e.g. SDK error path), `resolve_tracking` falls through to `per_run`. Fixed by propagating all cost/cache fields. **Q: Why can't we get `cost_usd` for Anthropic direct API calls?** The Anthropic Messages API does not return a dollar amount — only token counts. OpenRouter returns cost via response headers, so it uses `cost_usd` directly. The Claude Agent SDK *does* compute `total_cost_usd` internally, so SDK-mode OrchestratorBlock runs now get `cost_usd` tracking. For direct Anthropic LLM blocks the estimate uses per-token rates (see cache section below). **Q: What about labeling by source (autopilot vs block)?** Already tracked: `block_name` stores `copilot:SDK`, `copilot:Baseline`, or the actual block name. Visible in the raw logs table. Not added to the provider group-by (would explode row count); use the logs table filter instead. **Q: Is there double-counting between `tokens`, `per_run`, and `cost_usd`?** No. `resolve_tracking()` uses a strict preference hierarchy — exactly one tracking type per execution: `cost_usd` > `tokens` > provider heuristics > `per_run`. A single execution produces exactly one `PlatformCostLog` row. **Q: Should we track Anthropic prompt cache tokens (PR #12725)?** Yes — PR #12725 adds `cache_control` markers to Anthropic API calls, which causes the API to return `cache_read_input_tokens` and `cache_creation_input_tokens` alongside regular `input_tokens`. These have different billing rates: - Cache reads: **10%** of base input rate (much cheaper) - Cache writes: **125%** of base input rate (slightly more expensive, one-time) - Uncached input: **100%** of base rate Without tracking them separately, a flat-rate estimate on `total_input_tokens` would be wrong in both directions. ## What - **Per-model provider table**: SQL now groups by `(provider, tracking_type, model)`. `ProviderCostSummary` and the frontend `ProviderTable` show a model column. - **Cache token columns**: New `cacheReadTokens` and `cacheCreationTokens` columns in `PlatformCostLog` with matching migration. - **LLM block cache tracking**: `LLMResponse` captures `cache_read_input_tokens` / `cache_creation_input_tokens` from Anthropic responses. `NodeExecutionStats` gains `cache_read_token_count` / `cache_creation_token_count`. Both propagate to `PlatformCostEntry` and the DB. 
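A quick illustrative sketch before the remaining items: how the tracking-type hierarchy and the cache billing rates from the Why section fit together. Names here (`pick_tracking_type`, `estimate_tokens_cost_usd`, the rate parameters) are stand-ins, not the platform's actual signatures; the real logic lives in `resolve_tracking` and the frontend's `estimateCostForRow`:

```python
# Illustrative sketch only; mirrors the rules described in this PR.


def pick_tracking_type(
    provider_cost_usd: float | None,
    input_tokens: int,
    output_tokens: int,
) -> str:
    """Strict preference hierarchy: cost_usd > tokens > per_run.

    (The real resolve_tracking also applies provider heuristics between
    `tokens` and `per_run`; omitted here for brevity.)
    """
    if provider_cost_usd is not None:
        return "cost_usd"  # exact cost reported by the provider (e.g. OpenRouter)
    if input_tokens or output_tokens:
        return "tokens"  # token counts only (e.g. direct Anthropic API)
    return "per_run"  # nothing usable (e.g. SDK error path): flat fallback


def estimate_tokens_cost_usd(
    uncached_input: int,
    cache_read: int,
    cache_creation: int,
    output_tokens: int,
    base_input_rate: float,  # $ per input token (assumed config value)
    output_rate: float,  # $ per output token (assumed config value)
) -> float:
    """Cache-aware estimate for `tokens` rows."""
    return (
        uncached_input * base_input_rate  # uncached input: 100% of base
        + cache_read * base_input_rate * 0.10  # cache reads: 10%
        + cache_creation * base_input_rate * 1.25  # cache writes: 125%
        + output_tokens * output_rate
    )


# Example: 800 uncached + 400 cache-read + 100 cache-write input tokens and
# 200 output tokens at $3/MTok input, $15/MTok output comes to ~$0.0059,
# versus ~$0.0069 if all 1,300 input tokens were billed at the flat base rate.
```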
- **Copilot path**: `token_tracking.persist_and_record_usage` now writes cache tokens as dedicated `PlatformCostEntry` fields (was metadata-only).
- **OrchestratorBlock bug fix**: `_call_llm()` now includes `resp.provider_cost`, `resp.cache_read_tokens`, `resp.cache_creation_tokens` in the stats merge. SDK path captures `ResultMessage.total_cost_usd` as `provider_cost`.
- **Accurate cost estimation**: `estimateCostForRow` uses token-type-specific rates for `tokens` rows (uncached=100%, reads=10%, writes=125% of configured base rate).
- **Anthropic-via-OpenRouter routing**: when `open_router_api_key` is configured, `llm_call` transparently re-routes direct-Anthropic models through OpenRouter (model id becomes `anthropic/<model>`), so `provider_cost` is populated from OpenRouter's `x-total-cost` header instead of a token-rate estimate.

## How

`resolve_tracking` priority is unchanged. For direct Anthropic LLM blocks the tracking type remains `tokens` (the Anthropic API returns no dollar amount); when an OpenRouter key is configured, those calls are re-routed through OpenRouter and use `cost_usd` from its cost header instead. For OrchestratorBlock in SDK/autopilot mode it now correctly uses `cost_usd` because the Claude Agent SDK computes and returns `total_cost_usd`. For OpenRouter through OrchestratorBlock it now correctly uses `cost_usd` (was silently dropped before).

## Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] `ProviderCostSummary` SQL updated
  - [x] Cache token fields present in `PlatformCostEntry` and `PlatformCostLogCreateInput`
  - [x] Prisma client regenerated — all type checks pass
  - [x] Frontend `helpers.test.ts` updated for new `rateKey` format
  - [x] Pre-commit hooks pass (Black, Ruff, isort, tsc, Prisma generate)

---
 .../analytics/queries/platform_cost_log.sql  | 100 ++++
 .../backend/backend/blocks/llm.py            |  17 +-
 .../backend/backend/blocks/test/test_llm.py  |  56 ++-
 .../backend/backend/data/platform_cost.py    | 467 +++++++++---------
 .../backend/data/platform_cost_test.py       | 407 ++++++++++-----
 5 files changed, 697 insertions(+), 350 deletions(-)
 create mode 100644 autogpt_platform/analytics/queries/platform_cost_log.sql

diff --git a/autogpt_platform/analytics/queries/platform_cost_log.sql b/autogpt_platform/analytics/queries/platform_cost_log.sql
new file mode 100644
index 0000000000..b3e33d7515
--- /dev/null
+++ b/autogpt_platform/analytics/queries/platform_cost_log.sql
@@ -0,0 +1,100 @@
+-- =============================================================
+-- View: analytics.platform_cost_log
+-- Looker source alias: ds115 | Charts: 0
+-- =============================================================
+-- DESCRIPTION
+--   One row per platform cost log entry (last 90 days).
+--   Tracks real API spend at the call level: provider, model,
+--   token counts (including Anthropic cache tokens), cost in
+--   microdollars, and the block/execution that incurred the cost.
+--   Joins the User table to provide email for per-user breakdowns.
+--
+-- SOURCE TABLES
+--   platform.PlatformCostLog — Per-call cost records
+--   platform.User            — User email
+--
+-- OUTPUT COLUMNS
+--   id                  TEXT        Log entry UUID
+--   createdAt           TIMESTAMPTZ When the cost was recorded
+--   userId              TEXT        User who incurred the cost (nullable)
+--   email               TEXT        User email (nullable)
+--   graphExecId         TEXT        Graph execution UUID (nullable)
+--   nodeExecId          TEXT        Node execution UUID (nullable)
+--   blockName           TEXT        Block that made the API call (nullable)
+--   provider            TEXT        API provider, lowercase (e.g. 'openai', 'anthropic')
+--   model               TEXT        Model name (nullable)
+--   trackingType        TEXT        Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc.
+-- costMicrodollars    BIGINT      Cost in microdollars (divide by 1,000,000 for USD)
+-- costUsd             FLOAT       Cost in USD (costMicrodollars / 1,000,000)
+-- inputTokens         INT         Prompt/input tokens (nullable)
+-- outputTokens        INT         Completion/output tokens (nullable)
+-- cacheReadTokens     INT         Anthropic cache-read tokens billed at 10% (nullable)
+-- cacheCreationTokens INT         Anthropic cache-write tokens billed at 125% (nullable)
+-- totalTokens         INT         inputTokens + outputTokens (nullable if either is null)
+-- duration            FLOAT       API call duration in seconds (nullable)
+--
+-- WINDOW
+--   Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
+--
+-- EXAMPLE QUERIES
+--   -- Total spend by provider (last 90 days)
+--   SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls
+--   FROM analytics.platform_cost_log
+--   GROUP BY 1 ORDER BY total_usd DESC;
+--
+--   -- Spend by model
+--   SELECT provider, model, SUM("costUsd") AS total_usd,
+--          SUM("inputTokens") AS input_tokens,
+--          SUM("outputTokens") AS output_tokens
+--   FROM analytics.platform_cost_log
+--   WHERE model IS NOT NULL
+--   GROUP BY 1, 2 ORDER BY total_usd DESC;
+--
+--   -- Top 20 users by spend
+--   SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls
+--   FROM analytics.platform_cost_log
+--   WHERE "userId" IS NOT NULL
+--   GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20;
+--
+--   -- Daily spend trend
+--   SELECT DATE_TRUNC('day', "createdAt") AS day,
+--          SUM("costUsd") AS daily_usd,
+--          COUNT(*) AS calls
+--   FROM analytics.platform_cost_log
+--   GROUP BY 1 ORDER BY 1;
+--
+--   -- Cache hit rate for Anthropic (cache reads vs total input tokens)
+--   SELECT DATE_TRUNC('day', "createdAt") AS day,
+--          SUM("cacheReadTokens")::float /
+--            NULLIF(SUM(COALESCE("inputTokens", 0) + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate
+--   FROM analytics.platform_cost_log
+--   WHERE provider = 'anthropic'
+--   GROUP BY 1 ORDER BY 1;
+-- =============================================================
+
+SELECT
+    p."id"                  AS id,
+    p."createdAt"           AS "createdAt",
+    p."userId"              AS "userId",
+    u."email"               AS email,
+    p."graphExecId"         AS "graphExecId",
+    p."nodeExecId"          AS "nodeExecId",
+    p."blockName"           AS "blockName",
+    p."provider"            AS provider,
+    p."model"               AS model,
+    p."trackingType"        AS "trackingType",
+    p."costMicrodollars"    AS "costMicrodollars",
+    p."costMicrodollars"::float / 1000000.0 AS "costUsd",
+    p."inputTokens"         AS "inputTokens",
+    p."outputTokens"        AS "outputTokens",
+    p."cacheReadTokens"     AS "cacheReadTokens",
+    p."cacheCreationTokens" AS "cacheCreationTokens",
+    CASE
+        WHEN p."inputTokens" IS NOT NULL AND p."outputTokens" IS NOT NULL
+        THEN p."inputTokens" + p."outputTokens"
+        ELSE NULL
+    END AS "totalTokens",
+    p."duration"            AS duration
+FROM platform."PlatformCostLog" p
+LEFT JOIN platform."User" u ON u."id" = p."userId"
+WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days'
diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py
index 7c9bd53e75..52e32feb13 100644
--- a/autogpt_platform/backend/backend/blocks/llm.py
+++ b/autogpt_platform/backend/backend/blocks/llm.py
@@ -887,6 +887,24 @@ async def llm_call(
     provider = llm_model.metadata.provider
     context_window = llm_model.context_window
 
+    # Transparent OpenRouter routing for Anthropic models: when an OpenRouter API key
+    # is configured, route direct-Anthropic models through OpenRouter instead. This
+    # gives us the x-total-cost header for free, so provider_cost is always populated
+    # without manual token-rate arithmetic.
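+    # NOTE: this swaps the user's Anthropic credentials for the platform's
+    # OpenRouter key. When no OpenRouter key is configured, we fall through to
+    # the direct Anthropic path below, which keeps token-based tracking.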
+ or_key = settings.secrets.open_router_api_key + or_model_id: str | None = None + if provider == "anthropic" and or_key: + provider = "open_router" + credentials = APIKeyCredentials( + provider=ProviderName.OPEN_ROUTER, + title="OpenRouter (auto)", + api_key=SecretStr(or_key), + ) + or_model_id = f"anthropic/{llm_model.value}" + if compress_prompt_to_fit: result = await compress_context( messages=prompt, @@ -1134,7 +1149,7 @@ async def llm_call( "HTTP-Referer": "https://agpt.co", "X-Title": "AutoGPT", }, - model=llm_model.value, + model=or_model_id or llm_model.value, messages=prompt, # type: ignore max_tokens=max_tokens, tools=tools_param, # type: ignore diff --git a/autogpt_platform/backend/backend/blocks/test/test_llm.py b/autogpt_platform/backend/backend/blocks/test/test_llm.py index 9f7e41fc0d..e8eea20040 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_llm.py +++ b/autogpt_platform/backend/backend/blocks/test/test_llm.py @@ -77,7 +77,11 @@ class TestLLMStatsTracking: mock_response.usage = mock_usage mock_response.stop_reason = "end_turn" - with patch("anthropic.AsyncAnthropic") as mock_anthropic: + with ( + patch("anthropic.AsyncAnthropic") as mock_anthropic, + patch("backend.blocks.llm.settings") as mock_settings, + ): + mock_settings.secrets.open_router_api_key = "" mock_client = AsyncMock() mock_anthropic.return_value = mock_client mock_client.messages.create = AsyncMock(return_value=mock_response) @@ -96,6 +100,56 @@ class TestLLMStatsTracking: assert response.cache_creation_tokens == 50 assert response.response == "Test anthropic response" + @pytest.mark.asyncio + async def test_anthropic_routes_through_openrouter_when_key_present(self): + """When open_router_api_key is set, Anthropic models route via OpenRouter.""" + from pydantic import SecretStr + + import backend.blocks.llm as llm + from backend.data.model import APIKeyCredentials + + anthropic_creds = APIKeyCredentials( + id="test-anthropic-id", + provider="anthropic", + api_key=SecretStr("mock-anthropic-key"), + title="Mock Anthropic key", + ) + + mock_choice = MagicMock() + mock_choice.message.content = "routed response" + mock_choice.message.tool_calls = None + + mock_usage = MagicMock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 5 + + mock_response = MagicMock() + mock_response.choices = [mock_choice] + mock_response.usage = mock_usage + + mock_create = AsyncMock(return_value=mock_response) + + with ( + patch("openai.AsyncOpenAI") as mock_openai, + patch("backend.blocks.llm.settings") as mock_settings, + ): + mock_settings.secrets.open_router_api_key = "sk-or-test-key" + mock_client = MagicMock() + mock_openai.return_value = mock_client + mock_client.chat.completions.create = mock_create + + await llm.llm_call( + credentials=anthropic_creds, + llm_model=llm.LlmModel.CLAUDE_3_HAIKU, + prompt=[{"role": "user", "content": "Hello"}], + max_tokens=100, + ) + + # Verify OpenAI client was used (not Anthropic SDK) and model was prefixed + mock_openai.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + assert call_kwargs["model"] == "anthropic/claude-3-haiku-20240307" + @pytest.mark.asyncio async def test_ai_structured_response_block_tracks_stats(self): """Test that AIStructuredResponseGeneratorBlock correctly tracks stats.""" diff --git a/autogpt_platform/backend/backend/data/platform_cost.py b/autogpt_platform/backend/backend/data/platform_cost.py index b44bb37910..17915e115c 100644 --- a/autogpt_platform/backend/backend/data/platform_cost.py +++ 
b/autogpt_platform/backend/backend/data/platform_cost.py @@ -4,10 +4,10 @@ from datetime import datetime, timedelta, timezone from typing import Any from prisma.models import PlatformCostLog as PrismaLog -from prisma.types import PlatformCostLogCreateInput +from prisma.models import User as PrismaUser +from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput from pydantic import BaseModel -from backend.data.db import query_raw_with_schema from backend.util.cache import cached from backend.util.json import SafeJson @@ -15,7 +15,7 @@ logger = logging.getLogger(__name__) MICRODOLLARS_PER_USD = 1_000_000 -# Dashboard query limits — keep in sync with the SQL queries below +# Dashboard query limits MAX_PROVIDER_ROWS = 500 MAX_USER_ROWS = 100 @@ -169,53 +169,61 @@ class PlatformCostDashboard(BaseModel): total_users: int -def _build_where( +def _si(row: dict, field: str) -> int: + """Extract an integer from a Prisma group_by _sum dict. + + Prisma Python serialises BigInt/Int aggregate sums as strings; coerce to int. + """ + return int((row.get("_sum") or {}).get(field) or 0) + + +def _sf(row: dict, field: str) -> float: + """Extract a float from a Prisma group_by _sum dict.""" + return float((row.get("_sum") or {}).get(field) or 0.0) + + +def _ca(row: dict) -> int: + """Extract _count._all from a Prisma group_by row.""" + c = row.get("_count") or {} + return int(c.get("_all") or 0) if isinstance(c, dict) else int(c or 0) + + +def _build_prisma_where( start: datetime | None, end: datetime | None, provider: str | None, user_id: str | None, - table_alias: str = "", model: str | None = None, block_name: str | None = None, tracking_type: str | None = None, -) -> tuple[str, list[Any]]: - prefix = f"{table_alias}." if table_alias else "" - clauses: list[str] = [] - params: list[Any] = [] - idx = 1 +) -> PlatformCostLogWhereInput: + """Build a Prisma WhereInput for PlatformCostLog filters.""" + where: PlatformCostLogWhereInput = {} + + if start and end: + where["createdAt"] = {"gte": start, "lte": end} + elif start: + where["createdAt"] = {"gte": start} + elif end: + where["createdAt"] = {"lte": end} - if start: - clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz') - params.append(start) - idx += 1 - if end: - clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz') - params.append(end) - idx += 1 if provider: - # Provider names are normalized to lowercase at write time so a plain - # equality check is sufficient and the (provider, createdAt) index is used. - clauses.append(f'{prefix}"provider" = ${idx}') - params.append(provider.lower()) - idx += 1 - if user_id: - clauses.append(f'{prefix}"userId" = ${idx}') - params.append(user_id) - idx += 1 - if model: - clauses.append(f'{prefix}"model" = ${idx}') - params.append(model) - idx += 1 - if block_name: - clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})') - params.append(block_name) - idx += 1 - if tracking_type: - clauses.append(f'{prefix}"trackingType" = ${idx}') - params.append(tracking_type) - idx += 1 + where["provider"] = provider.lower() - return (" AND ".join(clauses) if clauses else "TRUE", params) + if user_id: + where["userId"] = user_id + + if model: + where["model"] = model + + if block_name: + # Case-insensitive match — mirrors the original LOWER() SQL filter. 
+ where["blockName"] = {"equals": block_name, "mode": "insensitive"} + + if tracking_type: + where["trackingType"] = tracking_type + + return where @cached(ttl_seconds=30) @@ -241,110 +249,107 @@ async def get_platform_cost_dashboard( """ if start is None: start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS) - where_p, params_p = _build_where( - start, end, provider, user_id, "p", model, block_name, tracking_type + + where = _build_prisma_where( + start, end, provider, user_id, model, block_name, tracking_type ) - by_provider_rows, by_user_rows, total_user_rows, total_agg_rows = ( + sum_fields = { + "costMicrodollars": True, + "inputTokens": True, + "outputTokens": True, + "cacheReadTokens": True, + "cacheCreationTokens": True, + "duration": True, + "trackingAmount": True, + } + + # Run all four aggregation queries in parallel. + by_provider_groups, by_user_groups, total_user_groups, total_agg_groups = ( await asyncio.gather( - query_raw_with_schema( - f""" - SELECT - p."provider", - p."trackingType" AS tracking_type, - p."model", - COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost, - COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens, - COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens, - COALESCE(SUM(p."cacheReadTokens"), 0)::bigint AS total_cache_read_tokens, - COALESCE(SUM(p."cacheCreationTokens"), 0)::bigint AS total_cache_creation_tokens, - COALESCE(SUM(p."duration"), 0)::float AS total_duration, - COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount, - COUNT(*)::bigint AS request_count - FROM {{schema_prefix}}"PlatformCostLog" p - WHERE {where_p} - GROUP BY p."provider", p."trackingType", p."model" - ORDER BY total_cost DESC - LIMIT {MAX_PROVIDER_ROWS} - """, - *params_p, + # (provider, trackingType, model) aggregation — no ORDER BY in ORM; + # sort by total cost descending in Python after fetch. + PrismaLog.prisma().group_by( + by=["provider", "trackingType", "model"], + where=where, + sum=sum_fields, + count=True, ), - query_raw_with_schema( - f""" - SELECT - p."userId" AS user_id, - u."email", - COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost, - COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens, - COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens, - COUNT(*)::bigint AS request_count - FROM {{schema_prefix}}"PlatformCostLog" p - LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId" - WHERE {where_p} - GROUP BY p."userId", u."email" - ORDER BY total_cost DESC - LIMIT {MAX_USER_ROWS} - """, - *params_p, + # userId aggregation — emails fetched separately below. + PrismaLog.prisma().group_by( + by=["userId"], + where=where, + sum=sum_fields, + count=True, ), - query_raw_with_schema( - f""" - SELECT COUNT(DISTINCT p."userId")::bigint AS cnt - FROM {{schema_prefix}}"PlatformCostLog" p - WHERE {where_p} - """, - *params_p, + # Distinct user count: group by userId, count groups. + PrismaLog.prisma().group_by( + by=["userId"], + where=where, + count=True, ), - # Separate aggregate query so dashboard totals are never derived - # from the capped by_provider_rows list. With model-level grouping, - # MAX_PROVIDER_ROWS is hit more easily; summing the capped rows - # would silently undercount once >500 (provider, type, model) exist. 
- query_raw_with_schema( - f""" - SELECT - COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost, - COUNT(*)::bigint AS request_count - FROM {{schema_prefix}}"PlatformCostLog" p - WHERE {where_p} - """, - *params_p, + # Total aggregate: group by provider (no limit) to sum across all + # matching rows. Summed in Python to get grand totals. + PrismaLog.prisma().group_by( + by=["provider"], + where=where, + sum={"costMicrodollars": True}, + count=True, ), ) ) - # Use the exact COUNT(DISTINCT userId) so total_users is not capped at - # MAX_USER_ROWS (which would silently report 100 for >100 active users). - total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0 - total_cost = int(total_agg_rows[0]["total_cost"]) if total_agg_rows else 0 - total_requests = int(total_agg_rows[0]["request_count"]) if total_agg_rows else 0 + # Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS. + by_provider_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True) + by_provider_groups = by_provider_groups[:MAX_PROVIDER_ROWS] + + # Sort by_user by total cost descending and cap at MAX_USER_ROWS. + by_user_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True) + by_user_groups = by_user_groups[:MAX_USER_ROWS] + + # Batch-fetch emails for the users in by_user. + user_ids = [r["userId"] for r in by_user_groups if r.get("userId") is not None] + email_by_user_id: dict[str, str | None] = {} + if user_ids: + users = await PrismaUser.prisma().find_many( + where={"id": {"in": user_ids}}, + ) + email_by_user_id = {u.id: u.email for u in users} + + # Total distinct users — exclude the NULL-userId group (deleted users). + total_users = len([g for g in total_user_groups if g.get("userId") is not None]) + + # Grand totals — sum across all provider groups (no LIMIT applied above). 
+ total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups) + total_requests = sum(_ca(r) for r in total_agg_groups) return PlatformCostDashboard( by_provider=[ ProviderCostSummary( provider=r["provider"], - tracking_type=r.get("tracking_type"), + tracking_type=r.get("trackingType"), model=r.get("model"), - total_cost_microdollars=r["total_cost"], - total_input_tokens=r["total_input_tokens"], - total_output_tokens=r["total_output_tokens"], - total_cache_read_tokens=r.get("total_cache_read_tokens", 0), - total_cache_creation_tokens=r.get("total_cache_creation_tokens", 0), - total_duration_seconds=r.get("total_duration", 0.0), - total_tracking_amount=r.get("total_tracking_amount", 0.0), - request_count=r["request_count"], + total_cost_microdollars=_si(r, "costMicrodollars"), + total_input_tokens=_si(r, "inputTokens"), + total_output_tokens=_si(r, "outputTokens"), + total_cache_read_tokens=_si(r, "cacheReadTokens"), + total_cache_creation_tokens=_si(r, "cacheCreationTokens"), + total_duration_seconds=_sf(r, "duration"), + total_tracking_amount=_sf(r, "trackingAmount"), + request_count=_ca(r), ) - for r in by_provider_rows + for r in by_provider_groups ], by_user=[ UserCostSummary( - user_id=r.get("user_id"), - email=_mask_email(r.get("email")), - total_cost_microdollars=r["total_cost"], - total_input_tokens=r["total_input_tokens"], - total_output_tokens=r["total_output_tokens"], - request_count=r["request_count"], + user_id=r.get("userId"), + email=_mask_email(email_by_user_id.get(r.get("userId") or "")), + total_cost_microdollars=_si(r, "costMicrodollars"), + total_input_tokens=_si(r, "inputTokens"), + total_output_tokens=_si(r, "outputTokens"), + request_count=_ca(r), ) - for r in by_user_rows + for r in by_user_groups ], total_cost_microdollars=total_cost, total_requests=total_requests, @@ -365,73 +370,41 @@ async def get_platform_cost_logs( ) -> tuple[list[CostLogRow], int]: if start is None: start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS) - where_sql, params = _build_where( - start, end, provider, user_id, "p", model, block_name, tracking_type - ) + where = _build_prisma_where( + start, end, provider, user_id, model, block_name, tracking_type + ) offset = (page - 1) * page_size - limit_idx = len(params) + 1 - offset_idx = len(params) + 2 - count_rows, rows = await asyncio.gather( - query_raw_with_schema( - f""" - SELECT COUNT(*)::bigint AS cnt - FROM {{schema_prefix}}"PlatformCostLog" p - WHERE {where_sql} - """, - *params, - ), - query_raw_with_schema( - f""" - SELECT - p."id", - p."createdAt" AS created_at, - p."userId" AS user_id, - u."email", - p."graphExecId" AS graph_exec_id, - p."nodeExecId" AS node_exec_id, - p."blockName" AS block_name, - p."provider", - p."trackingType" AS tracking_type, - p."costMicrodollars" AS cost_microdollars, - p."inputTokens" AS input_tokens, - p."outputTokens" AS output_tokens, - p."cacheReadTokens" AS cache_read_tokens, - p."cacheCreationTokens" AS cache_creation_tokens, - p."duration", - p."model" - FROM {{schema_prefix}}"PlatformCostLog" p - LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId" - WHERE {where_sql} - ORDER BY p."createdAt" DESC, p."id" DESC - LIMIT ${limit_idx} OFFSET ${offset_idx} - """, - *params, - page_size, - offset, + total, rows = await asyncio.gather( + PrismaLog.prisma().count(where=where), + PrismaLog.prisma().find_many( + where=where, + include={"User": True}, + order=[{"createdAt": "desc"}, {"id": "desc"}], + take=page_size, + skip=offset, ), ) - total = count_rows[0]["cnt"] 
if count_rows else 0 logs = [ CostLogRow( - id=r["id"], - created_at=r["created_at"], - user_id=r.get("user_id"), - email=_mask_email(r.get("email")), - graph_exec_id=r.get("graph_exec_id"), - node_exec_id=r.get("node_exec_id"), - block_name=r["block_name"], - provider=r["provider"], - tracking_type=r.get("tracking_type"), - cost_microdollars=r.get("cost_microdollars"), - input_tokens=r.get("input_tokens"), - output_tokens=r.get("output_tokens"), - cache_read_tokens=r.get("cache_read_tokens"), - cache_creation_tokens=r.get("cache_creation_tokens"), - duration=r.get("duration"), - model=r.get("model"), + id=r.id, + created_at=r.createdAt, + user_id=r.userId, + email=_mask_email(r.User.email if r.User else None), + graph_exec_id=r.graphExecId, + node_exec_id=r.nodeExecId, + block_name=r.blockName or "", + provider=r.provider, + tracking_type=r.trackingType, + cost_microdollars=r.costMicrodollars, + input_tokens=r.inputTokens, + output_tokens=r.outputTokens, + cache_read_tokens=getattr(r, "cacheReadTokens", None), + cache_creation_tokens=getattr(r, "cacheCreationTokens", None), + duration=r.duration, + model=r.model, ) for r in rows ] @@ -457,38 +430,16 @@ async def get_platform_cost_logs_for_export( """ if start is None: start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS) - where_sql, params = _build_where( - start, end, provider, user_id, "p", model, block_name, tracking_type - ) - limit_idx = len(params) + 1 - rows = await query_raw_with_schema( - f""" - SELECT - p."id", - p."createdAt" AS created_at, - p."userId" AS user_id, - u."email", - p."graphExecId" AS graph_exec_id, - p."nodeExecId" AS node_exec_id, - p."blockName" AS block_name, - p."provider", - p."trackingType" AS tracking_type, - p."costMicrodollars" AS cost_microdollars, - p."inputTokens" AS input_tokens, - p."outputTokens" AS output_tokens, - p."cacheReadTokens" AS cache_read_tokens, - p."cacheCreationTokens" AS cache_creation_tokens, - p."duration", - p."model" - FROM {{schema_prefix}}"PlatformCostLog" p - LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId" - WHERE {where_sql} - ORDER BY p."createdAt" DESC, p."id" DESC - LIMIT ${limit_idx} - """, - *params, - EXPORT_MAX_ROWS + 1, + where = _build_prisma_where( + start, end, provider, user_id, model, block_name, tracking_type + ) + + rows = await PrismaLog.prisma().find_many( + where=where, + include={"User": True}, + order=[{"createdAt": "desc"}, {"id": "desc"}], + take=EXPORT_MAX_ROWS + 1, ) truncated = len(rows) > EXPORT_MAX_ROWS @@ -496,22 +447,80 @@ async def get_platform_cost_logs_for_export( return [ CostLogRow( - id=r["id"], - created_at=r["created_at"], - user_id=r.get("user_id"), - email=_mask_email(r.get("email")), - graph_exec_id=r.get("graph_exec_id"), - node_exec_id=r.get("node_exec_id"), - block_name=r["block_name"], - provider=r["provider"], - tracking_type=r.get("tracking_type"), - cost_microdollars=r.get("cost_microdollars"), - input_tokens=r.get("input_tokens"), - output_tokens=r.get("output_tokens"), - cache_read_tokens=r.get("cache_read_tokens"), - cache_creation_tokens=r.get("cache_creation_tokens"), - duration=r.get("duration"), - model=r.get("model"), + id=r.id, + created_at=r.createdAt, + user_id=r.userId, + email=_mask_email(r.User.email if r.User else None), + graph_exec_id=r.graphExecId, + node_exec_id=r.nodeExecId, + block_name=r.blockName or "", + provider=r.provider, + tracking_type=r.trackingType, + cost_microdollars=r.costMicrodollars, + input_tokens=r.inputTokens, + output_tokens=r.outputTokens, + 
cache_read_tokens=getattr(r, "cacheReadTokens", None), + cache_creation_tokens=getattr(r, "cacheCreationTokens", None), + duration=r.duration, + model=r.model, ) for r in rows ], truncated + + +# --------------------------------------------------------------------------- +# Helpers kept for backward-compatibility with existing tests. +# New code should not use these — use _build_prisma_where instead. +# --------------------------------------------------------------------------- + + +def _build_where( + start: datetime | None, + end: datetime | None, + provider: str | None, + user_id: str | None, + table_alias: str = "", + model: str | None = None, + block_name: str | None = None, + tracking_type: str | None = None, +) -> tuple[str, list[Any]]: + """Legacy SQL WHERE builder — retained so existing unit tests still pass. + + Only used by tests that verify the SQL-string generation logic. All + production code uses _build_prisma_where instead. + """ + prefix = f"{table_alias}." if table_alias else "" + clauses: list[str] = [] + params: list[Any] = [] + idx = 1 + + if start: + clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz') + params.append(start) + idx += 1 + if end: + clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz') + params.append(end) + idx += 1 + if provider: + clauses.append(f'{prefix}"provider" = ${idx}') + params.append(provider.lower()) + idx += 1 + if user_id: + clauses.append(f'{prefix}"userId" = ${idx}') + params.append(user_id) + idx += 1 + if model: + clauses.append(f'{prefix}"model" = ${idx}') + params.append(model) + idx += 1 + if block_name: + clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})') + params.append(block_name) + idx += 1 + if tracking_type: + clauses.append(f'{prefix}"trackingType" = ${idx}') + params.append(tracking_type) + idx += 1 + + return (" AND ".join(clauses) if clauses else "TRUE", params) diff --git a/autogpt_platform/backend/backend/data/platform_cost_test.py b/autogpt_platform/backend/backend/data/platform_cost_test.py index 758e97d37b..dacd2c42ea 100644 --- a/autogpt_platform/backend/backend/data/platform_cost_test.py +++ b/autogpt_platform/backend/backend/data/platform_cost_test.py @@ -1,7 +1,7 @@ """Unit tests for helpers and async functions in platform_cost module.""" from datetime import datetime, timezone -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest from prisma import Json @@ -225,6 +225,41 @@ class TestLogPlatformCostSafe: mock_create.assert_awaited_once() +def _make_group_by_row( + provider: str = "openai", + tracking_type: str | None = "tokens", + model: str | None = None, + cost: int = 5000, + input_tokens: int = 1000, + output_tokens: int = 500, + cache_read_tokens: int = 0, + cache_creation_tokens: int = 0, + duration: float = 10.5, + tracking_amount: float = 0.0, + count: int = 3, + user_id: str | None = None, +) -> dict: + row: dict = { + "_sum": { + "costMicrodollars": cost, + "inputTokens": input_tokens, + "outputTokens": output_tokens, + "cacheReadTokens": cache_read_tokens, + "cacheCreationTokens": cache_creation_tokens, + "duration": duration, + "trackingAmount": tracking_amount, + }, + "_count": {"_all": count}, + } + if user_id is not None: + row["userId"] = user_id + else: + row["provider"] = provider + row["trackingType"] = tracking_type + row["model"] = model + return row + + class TestGetPlatformCostDashboard: def setup_method(self): # @cached stores results in-process; clear between tests to avoid bleed. 
@@ -232,35 +267,44 @@ class TestGetPlatformCostDashboard: @pytest.mark.asyncio async def test_returns_dashboard_with_data(self): - provider_rows = [ - { - "provider": "openai", - "tracking_type": "tokens", - "total_cost": 5000, - "total_input_tokens": 1000, - "total_output_tokens": 500, - "total_duration": 10.5, - "request_count": 3, - } - ] - user_rows = [ - { - "user_id": "u1", - "email": "a@b.com", - "total_cost": 5000, - "total_input_tokens": 1000, - "total_output_tokens": 500, - "request_count": 3, - } - ] - # Dashboard runs 4 queries: by_provider, by_user, COUNT(DISTINCT userId), - # and a separate total aggregate (total_cost + request_count with no LIMIT). - agg_rows = [{"total_cost": 5000, "request_count": 3}] - mock_query = AsyncMock( - side_effect=[provider_rows, user_rows, [{"cnt": 1}], agg_rows] + provider_row = _make_group_by_row( + provider="openai", + tracking_type="tokens", + cost=5000, + input_tokens=1000, + output_tokens=500, + duration=10.5, + count=3, ) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + user_row = _make_group_by_row(user_id="u1", cost=5000, count=3) + + mock_user = MagicMock() + mock_user.id = "u1" + mock_user.email = "a@b.com" + + mock_actions = MagicMock() + mock_actions.group_by = AsyncMock( + side_effect=[ + [provider_row], # by_provider + [user_row], # by_user + [{"userId": "u1"}], # distinct users + [provider_row], # total agg + ] + ) + mock_actions.find_many = AsyncMock(return_value=[mock_user]) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch( + "backend.data.platform_cost.PrismaUser.prisma", + return_value=mock_actions, + ), + ): dashboard = await get_platform_cost_dashboard() + assert dashboard.total_cost_microdollars == 5000 assert dashboard.total_requests == 3 assert dashboard.total_users == 1 @@ -272,10 +316,67 @@ class TestGetPlatformCostDashboard: assert dashboard.by_user[0].email == "a***@b.com" @pytest.mark.asyncio - async def test_returns_empty_dashboard(self): - mock_query = AsyncMock(side_effect=[[], [], [], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + async def test_cache_tokens_aggregated_not_hardcoded(self): + """cache_read_tokens and cache_creation_tokens must be read from the + DB aggregation, not hardcoded to 0 (regression guard for Sentry report).""" + provider_row = _make_group_by_row( + provider="anthropic", + tracking_type="tokens", + cost=1000, + input_tokens=800, + output_tokens=200, + cache_read_tokens=400, + cache_creation_tokens=100, + count=1, + ) + user_row = _make_group_by_row(user_id="u2", cost=1000, count=1) + + mock_actions = MagicMock() + mock_actions.group_by = AsyncMock( + side_effect=[ + [provider_row], # by_provider + [user_row], # by_user + [{"userId": "u2"}], # distinct users + [provider_row], # total agg + ] + ) + mock_actions.find_many = AsyncMock(return_value=[]) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch( + "backend.data.platform_cost.PrismaUser.prisma", + return_value=mock_actions, + ), + ): dashboard = await get_platform_cost_dashboard() + + assert len(dashboard.by_provider) == 1 + row = dashboard.by_provider[0] + assert row.total_cache_read_tokens == 400 + assert row.total_cache_creation_tokens == 100 + + @pytest.mark.asyncio + async def test_returns_empty_dashboard(self): + mock_actions = MagicMock() + mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []]) + mock_actions.find_many = 
AsyncMock(return_value=[]) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch( + "backend.data.platform_cost.PrismaUser.prisma", + return_value=mock_actions, + ), + ): + dashboard = await get_platform_cost_dashboard() + assert dashboard.total_cost_microdollars == 0 assert dashboard.total_requests == 0 assert dashboard.total_users == 0 @@ -285,160 +386,228 @@ class TestGetPlatformCostDashboard: @pytest.mark.asyncio async def test_passes_filters_to_queries(self): start = datetime(2026, 1, 1, tzinfo=timezone.utc) - mock_query = AsyncMock(side_effect=[[], [], [], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + + mock_actions = MagicMock() + mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []]) + mock_actions.find_many = AsyncMock(return_value=[]) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch( + "backend.data.platform_cost.PrismaUser.prisma", + return_value=mock_actions, + ), + ): await get_platform_cost_dashboard( start=start, provider="openai", user_id="u1" ) - assert mock_query.await_count == 4 - first_call_sql = mock_query.call_args_list[0][0][0] - assert "createdAt" in first_call_sql + + # group_by called 4 times (by_provider, by_user, distinct users, totals) + assert mock_actions.group_by.await_count == 4 + # The where dict passed to the first call should include createdAt + first_call_kwargs = mock_actions.group_by.call_args_list[0][1] + assert "createdAt" in first_call_kwargs.get("where", {}) + + +def _make_prisma_log_row( + i: int = 0, + user_email: str | None = None, +) -> MagicMock: + row = MagicMock() + row.id = f"log-{i}" + row.createdAt = datetime(2026, 3, 1, tzinfo=timezone.utc) + row.userId = "u1" + row.graphExecId = None + row.nodeExecId = None + row.blockName = "TestBlock" + row.provider = "openai" + row.trackingType = "tokens" + row.costMicrodollars = 1000 + row.inputTokens = 10 + row.outputTokens = 5 + row.duration = 0.5 + row.model = "gpt-4" + # cacheReadTokens / cacheCreationTokens may not exist on older Prisma clients + row.configure_mock(**{"cacheReadTokens": None, "cacheCreationTokens": None}) + if user_email is not None: + row.User = MagicMock() + row.User.email = user_email + else: + row.User = None + return row class TestGetPlatformCostLogs: @pytest.mark.asyncio async def test_returns_logs_and_total(self): - count_rows = [{"cnt": 1}] - log_rows = [ - { - "id": "log-1", - "created_at": datetime(2026, 3, 1, tzinfo=timezone.utc), - "user_id": "u1", - "email": "a@b.com", - "graph_exec_id": "g1", - "node_exec_id": "n1", - "block_name": "TestBlock", - "provider": "openai", - "tracking_type": "tokens", - "cost_microdollars": 5000, - "input_tokens": 100, - "output_tokens": 50, - "duration": 1.5, - "model": "gpt-4", - } - ] - mock_query = AsyncMock(side_effect=[count_rows, log_rows]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + row = _make_prisma_log_row(0, user_email="a@b.com") + mock_actions = MagicMock() + mock_actions.count = AsyncMock(return_value=1) + mock_actions.find_many = AsyncMock(return_value=[row]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, total = await get_platform_cost_logs(page=1, page_size=10) + assert total == 1 assert len(logs) == 1 - assert logs[0].id == "log-1" + assert logs[0].id == "log-0" assert logs[0].provider == "openai" assert logs[0].model == "gpt-4" @pytest.mark.asyncio 
async def test_returns_empty_when_no_data(self): - mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + mock_actions = MagicMock() + mock_actions.count = AsyncMock(return_value=0) + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, total = await get_platform_cost_logs() + assert total == 0 assert logs == [] @pytest.mark.asyncio async def test_pagination_offset(self): - mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): - logs, total = await get_platform_cost_logs(page=3, page_size=25) - assert total == 100 - second_call_args = mock_query.call_args_list[1][0] - assert 25 in second_call_args # page_size - assert 50 in second_call_args # offset = (3-1) * 25 + mock_actions = MagicMock() + mock_actions.count = AsyncMock(return_value=100) + mock_actions.find_many = AsyncMock(return_value=[]) - @pytest.mark.asyncio - async def test_empty_count_returns_zero(self): - mock_query = AsyncMock(side_effect=[[], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): - logs, total = await get_platform_cost_logs() - assert total == 0 + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): + logs, total = await get_platform_cost_logs(page=3, page_size=25) + + assert total == 100 + find_many_call = mock_actions.find_many.call_args[1] + assert find_many_call["take"] == 25 + assert find_many_call["skip"] == 50 # (3-1) * 25 @pytest.mark.asyncio async def test_explicit_start_skips_default(self): start = datetime(2026, 1, 1, tzinfo=timezone.utc) - mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + mock_actions = MagicMock() + mock_actions.count = AsyncMock(return_value=0) + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, total = await get_platform_cost_logs(start=start) + assert total == 0 - - -def _make_log_row(i: int = 0) -> dict: - return { - "id": f"log-{i}", - "created_at": datetime(2026, 3, 1, tzinfo=timezone.utc), - "user_id": "u1", - "email": None, - "graph_exec_id": None, - "node_exec_id": None, - "block_name": "TestBlock", - "provider": "openai", - "tracking_type": "tokens", - "cost_microdollars": 1000, - "input_tokens": 10, - "output_tokens": 5, - "duration": 0.5, - "model": "gpt-4", - "cache_read_tokens": None, - "cache_creation_tokens": None, - } + where = mock_actions.count.call_args[1]["where"] + # start provided — should appear in the where filter + assert "createdAt" in where class TestGetPlatformCostLogsForExport: @pytest.mark.asyncio async def test_returns_logs_not_truncated(self): - rows = [_make_log_row(0)] - mock_query = AsyncMock(return_value=rows) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + row = _make_prisma_log_row(0) + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[row]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, truncated = await get_platform_cost_logs_for_export() + assert len(logs) == 1 assert truncated is False assert logs[0].id == "log-0" @pytest.mark.asyncio async def test_returns_empty_not_truncated(self): - 
mock_query = AsyncMock(return_value=[]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, truncated = await get_platform_cost_logs_for_export() + assert logs == [] assert truncated is False @pytest.mark.asyncio async def test_truncates_at_export_max_rows(self): - rows = [_make_log_row(i) for i in range(3)] - mock_query = AsyncMock(return_value=rows) - with patch( - "backend.data.platform_cost.query_raw_with_schema", new=mock_query - ), patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2): + rows = [_make_prisma_log_row(i) for i in range(3)] + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=rows) + + with ( + patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ), + patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2), + ): logs, truncated = await get_platform_cost_logs_for_export() + assert len(logs) == 2 assert truncated is True @pytest.mark.asyncio async def test_passes_model_block_tracking_filters(self): - mock_query = AsyncMock(return_value=[]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): await get_platform_cost_logs_for_export( model="gpt-4", block_name="LLMBlock", tracking_type="tokens" ) - call_args = mock_query.call_args[0] - assert "gpt-4" in call_args - assert "LLMBlock" in call_args - assert "tokens" in call_args + + where = mock_actions.find_many.call_args[1]["where"] + assert where.get("model") == "gpt-4" + assert where.get("trackingType") == "tokens" + # blockName uses a dict filter for case-insensitive match + assert "blockName" in where @pytest.mark.asyncio async def test_maps_cache_tokens(self): - row = _make_log_row(0) - row["cache_read_tokens"] = 50 - row["cache_creation_tokens"] = 25 - mock_query = AsyncMock(return_value=[row]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + row = _make_prisma_log_row(0) + row.configure_mock(**{"cacheReadTokens": 50, "cacheCreationTokens": 25}) + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[row]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, _ = await get_platform_cost_logs_for_export() + assert logs[0].cache_read_tokens == 50 assert logs[0].cache_creation_tokens == 25 @pytest.mark.asyncio async def test_explicit_start_skips_default(self): start = datetime(2026, 1, 1, tzinfo=timezone.utc) - mock_query = AsyncMock(return_value=[]) - with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query): + mock_actions = MagicMock() + mock_actions.find_many = AsyncMock(return_value=[]) + + with patch( + "backend.data.platform_cost.PrismaLog.prisma", + return_value=mock_actions, + ): logs, truncated = await get_platform_cost_logs_for_export(start=start) + assert logs == [] assert truncated is False + where = mock_actions.find_many.call_args[1]["where"] + assert "createdAt" in where