Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into feat/copilot-pending-messages
autogpt_platform/analytics/queries/platform_cost_log.sql (new file, 100 lines)
@@ -0,0 +1,100 @@
-- =============================================================
-- View: analytics.platform_cost_log
-- Looker source alias: ds115 | Charts: 0
-- =============================================================
-- DESCRIPTION
-- One row per platform cost log entry (last 90 days).
-- Tracks real API spend at the call level: provider, model,
-- token counts (including Anthropic cache tokens), cost in
-- microdollars, and the block/execution that incurred the cost.
-- Joins the User table to provide email for per-user breakdowns.
--
-- SOURCE TABLES
-- platform.PlatformCostLog — Per-call cost records
-- platform.User            — User email
--
-- OUTPUT COLUMNS
-- id                   TEXT        Log entry UUID
-- createdAt            TIMESTAMPTZ When the cost was recorded
-- userId               TEXT        User who incurred the cost (nullable)
-- email                TEXT        User email (nullable)
-- graphExecId          TEXT        Graph execution UUID (nullable)
-- nodeExecId           TEXT        Node execution UUID (nullable)
-- blockName            TEXT        Block that made the API call (nullable)
-- provider             TEXT        API provider, lowercase (e.g. 'openai', 'anthropic')
-- model                TEXT        Model name (nullable)
-- trackingType         TEXT        Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc.
-- costMicrodollars     BIGINT      Cost in microdollars (divide by 1,000,000 for USD)
-- costUsd              FLOAT       Cost in USD (costMicrodollars / 1,000,000)
-- inputTokens          INT         Prompt/input tokens (nullable)
-- outputTokens         INT         Completion/output tokens (nullable)
-- cacheReadTokens      INT         Anthropic cache-read tokens billed at 10% (nullable)
-- cacheCreationTokens  INT         Anthropic cache-write tokens billed at 125% (nullable)
-- totalTokens          INT         inputTokens + outputTokens (nullable if either is null)
-- duration             FLOAT       API call duration in seconds (nullable)
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend by provider (last 90 days)
-- SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- GROUP BY 1 ORDER BY total_usd DESC;
--
-- -- Spend by model
-- SELECT provider, model, SUM("costUsd") AS total_usd,
--        SUM("inputTokens") AS input_tokens,
--        SUM("outputTokens") AS output_tokens
-- FROM analytics.platform_cost_log
-- WHERE model IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_usd DESC;
--
-- -- Top 20 users by spend
-- SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- WHERE "userId" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20;
--
-- -- Daily spend trend
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
--        SUM("costUsd") AS daily_usd,
--        COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- GROUP BY 1 ORDER BY 1;
--
-- -- Cache hit rate for Anthropic (cache reads vs total reads)
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
--        SUM("cacheReadTokens")::float /
--        NULLIF(SUM("inputTokens" + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate
-- FROM analytics.platform_cost_log
-- WHERE provider = 'anthropic'
-- GROUP BY 1 ORDER BY 1;
-- =============================================================

SELECT
    p."id" AS id,
    p."createdAt" AS createdAt,
    p."userId" AS userId,
    u."email" AS email,
    p."graphExecId" AS graphExecId,
    p."nodeExecId" AS nodeExecId,
    p."blockName" AS blockName,
    p."provider" AS provider,
    p."model" AS model,
    p."trackingType" AS trackingType,
    p."costMicrodollars" AS costMicrodollars,
    p."costMicrodollars"::float / 1000000.0 AS costUsd,
    p."inputTokens" AS inputTokens,
    p."outputTokens" AS outputTokens,
    p."cacheReadTokens" AS cacheReadTokens,
    p."cacheCreationTokens" AS cacheCreationTokens,
    CASE
        WHEN p."inputTokens" IS NOT NULL AND p."outputTokens" IS NOT NULL
        THEN p."inputTokens" + p."outputTokens"
        ELSE NULL
    END AS totalTokens,
    p."duration" AS duration
FROM platform."PlatformCostLog" p
LEFT JOIN platform."User" u ON u."id" = p."userId"
WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days'
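The cache-token columns above follow Anthropic prompt-caching billing as described in the view's own comments: cache reads are billed at roughly 10% of the normal input-token rate, cache writes at 125%. The following is a minimal Python sketch of turning one row of this view into an effective billed-token figure and a cache hit rate; the multipliers and the plain-dict row shape are assumptions taken from those comments, not code shipped with this change.

def cache_adjusted_tokens(row: dict) -> dict:
    """Illustrative only: derive cache-adjusted figures from one view row."""
    input_tokens = row.get("inputTokens") or 0
    cache_read = row.get("cacheReadTokens") or 0
    cache_write = row.get("cacheCreationTokens") or 0

    # Tokens served from cache cost ~10% of a normal input token;
    # tokens written to cache cost ~125% (per the column comments above).
    effective_input = input_tokens + 0.10 * cache_read + 1.25 * cache_write

    total_reads = input_tokens + cache_read
    hit_rate = cache_read / total_reads if total_reads else 0.0
    return {"effective_input_tokens": effective_input, "cache_hit_rate": hit_rate}

# Example: 800 uncached input tokens, 400 cache reads, 100 cache writes.
print(cache_adjusted_tokens(
    {"inputTokens": 800, "cacheReadTokens": 400, "cacheCreationTokens": 100}
))  # {'effective_input_tokens': 965.0, 'cache_hit_rate': 0.333...}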
@@ -887,6 +887,21 @@ async def llm_call(
    provider = llm_model.metadata.provider
    context_window = llm_model.context_window

    # Transparent OpenRouter routing for Anthropic models: when an OpenRouter API key
    # is configured, route direct-Anthropic models through OpenRouter instead. This
    # gives us the x-total-cost header for free, so provider_cost is always populated
    # without manual token-rate arithmetic.
    or_key = settings.secrets.open_router_api_key
    or_model_id: str | None = None
    if provider == "anthropic" and or_key:
        provider = "open_router"
        credentials = APIKeyCredentials(
            provider=ProviderName.OPEN_ROUTER,
            title="OpenRouter (auto)",
            api_key=SecretStr(or_key),
        )
        or_model_id = f"anthropic/{llm_model.value}"

    if compress_prompt_to_fit:
        result = await compress_context(
            messages=prompt,
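The routing rule in this hunk is small enough to state as a pure function: if the model's native provider is Anthropic and an OpenRouter key is configured, switch the provider and prefix the model id with "anthropic/". Below is a standalone sketch of that decision; the function name and tuple return are illustrative, not part of the patch.

def route_model(
    provider: str, model_value: str, open_router_api_key: str | None
) -> tuple[str, str]:
    """Return (effective_provider, effective_model_id) for an LLM call.

    Mirrors the diff above: Anthropic models are transparently re-routed
    through OpenRouter when a key is available, using the
    "anthropic/<model>" OpenRouter id; everything else is left untouched.
    """
    if provider == "anthropic" and open_router_api_key:
        return "open_router", f"anthropic/{model_value}"
    return provider, model_value

# Matches the assertion in the test added later in this commit.
assert route_model("anthropic", "claude-3-haiku-20240307", "sk-or-test") == (
    "open_router",
    "anthropic/claude-3-haiku-20240307",
)
assert route_model("openai", "gpt-4", "sk-or-test") == ("openai", "gpt-4")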
@@ -1134,7 +1149,7 @@ async def llm_call(
                "HTTP-Referer": "https://agpt.co",
                "X-Title": "AutoGPT",
            },
            model=llm_model.value,
            model=or_model_id or llm_model.value,
            messages=prompt,  # type: ignore
            max_tokens=max_tokens,
            tools=tools_param,  # type: ignore
@@ -77,7 +77,11 @@ class TestLLMStatsTracking:
        mock_response.usage = mock_usage
        mock_response.stop_reason = "end_turn"

        with patch("anthropic.AsyncAnthropic") as mock_anthropic:
        with (
            patch("anthropic.AsyncAnthropic") as mock_anthropic,
            patch("backend.blocks.llm.settings") as mock_settings,
        ):
            mock_settings.secrets.open_router_api_key = ""
            mock_client = AsyncMock()
            mock_anthropic.return_value = mock_client
            mock_client.messages.create = AsyncMock(return_value=mock_response)
@@ -96,6 +100,56 @@ class TestLLMStatsTracking:
            assert response.cache_creation_tokens == 50
            assert response.response == "Test anthropic response"

    @pytest.mark.asyncio
    async def test_anthropic_routes_through_openrouter_when_key_present(self):
        """When open_router_api_key is set, Anthropic models route via OpenRouter."""
        from pydantic import SecretStr

        import backend.blocks.llm as llm
        from backend.data.model import APIKeyCredentials

        anthropic_creds = APIKeyCredentials(
            id="test-anthropic-id",
            provider="anthropic",
            api_key=SecretStr("mock-anthropic-key"),
            title="Mock Anthropic key",
        )

        mock_choice = MagicMock()
        mock_choice.message.content = "routed response"
        mock_choice.message.tool_calls = None

        mock_usage = MagicMock()
        mock_usage.prompt_tokens = 10
        mock_usage.completion_tokens = 5

        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = mock_usage

        mock_create = AsyncMock(return_value=mock_response)

        with (
            patch("openai.AsyncOpenAI") as mock_openai,
            patch("backend.blocks.llm.settings") as mock_settings,
        ):
            mock_settings.secrets.open_router_api_key = "sk-or-test-key"
            mock_client = MagicMock()
            mock_openai.return_value = mock_client
            mock_client.chat.completions.create = mock_create

            await llm.llm_call(
                credentials=anthropic_creds,
                llm_model=llm.LlmModel.CLAUDE_3_HAIKU,
                prompt=[{"role": "user", "content": "Hello"}],
                max_tokens=100,
            )

            # Verify OpenAI client was used (not Anthropic SDK) and model was prefixed
            mock_openai.assert_called_once()
            call_kwargs = mock_create.call_args.kwargs
            assert call_kwargs["model"] == "anthropic/claude-3-haiku-20240307"

    @pytest.mark.asyncio
    async def test_ai_structured_response_block_tracks_stats(self):
        """Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""
@@ -4,10 +4,10 @@ from datetime import datetime, timedelta, timezone
from typing import Any

from prisma.models import PlatformCostLog as PrismaLog
from prisma.types import PlatformCostLogCreateInput
from prisma.models import User as PrismaUser
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
from pydantic import BaseModel

from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
from backend.util.json import SafeJson
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)

MICRODOLLARS_PER_USD = 1_000_000

# Dashboard query limits — keep in sync with the SQL queries below
# Dashboard query limits
MAX_PROVIDER_ROWS = 500
MAX_USER_ROWS = 100
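Costs are stored as integer microdollars (MICRODOLLARS_PER_USD = 1,000,000), which keeps aggregation exact; conversion to USD happens only at the edges, such as the view's costUsd column or display code. A small illustrative round-trip, written for this note rather than taken from the module:

MICRODOLLARS_PER_USD = 1_000_000  # same constant as in platform_cost.py

def usd_to_microdollars(usd: float) -> int:
    # Round to the nearest microdollar so repeated conversions stay stable.
    return round(usd * MICRODOLLARS_PER_USD)

def microdollars_to_usd(microdollars: int) -> float:
    return microdollars / MICRODOLLARS_PER_USD

assert usd_to_microdollars(0.0042) == 4200
assert microdollars_to_usd(5_000) == 0.005          # 5000 microdollars = half a cent
assert microdollars_to_usd(usd_to_microdollars(1.25)) == 1.25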
@@ -169,53 +169,61 @@ class PlatformCostDashboard(BaseModel):
    total_users: int


def _build_where(
def _si(row: dict, field: str) -> int:
    """Extract an integer from a Prisma group_by _sum dict.

    Prisma Python serialises BigInt/Int aggregate sums as strings; coerce to int.
    """
    return int((row.get("_sum") or {}).get(field) or 0)


def _sf(row: dict, field: str) -> float:
    """Extract a float from a Prisma group_by _sum dict."""
    return float((row.get("_sum") or {}).get(field) or 0.0)


def _ca(row: dict) -> int:
    """Extract _count._all from a Prisma group_by row."""
    c = row.get("_count") or {}
    return int(c.get("_all") or 0) if isinstance(c, dict) else int(c or 0)


def _build_prisma_where(
    start: datetime | None,
    end: datetime | None,
    provider: str | None,
    user_id: str | None,
    table_alias: str = "",
    model: str | None = None,
    block_name: str | None = None,
    tracking_type: str | None = None,
) -> tuple[str, list[Any]]:
    prefix = f"{table_alias}." if table_alias else ""
    clauses: list[str] = []
    params: list[Any] = []
    idx = 1
) -> PlatformCostLogWhereInput:
    """Build a Prisma WhereInput for PlatformCostLog filters."""
    where: PlatformCostLogWhereInput = {}

    if start and end:
        where["createdAt"] = {"gte": start, "lte": end}
    elif start:
        where["createdAt"] = {"gte": start}
    elif end:
        where["createdAt"] = {"lte": end}

    if start:
        clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
        params.append(start)
        idx += 1
    if end:
        clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
        params.append(end)
        idx += 1
    if provider:
        # Provider names are normalized to lowercase at write time so a plain
        # equality check is sufficient and the (provider, createdAt) index is used.
        clauses.append(f'{prefix}"provider" = ${idx}')
        params.append(provider.lower())
        idx += 1
    if user_id:
        clauses.append(f'{prefix}"userId" = ${idx}')
        params.append(user_id)
        idx += 1
    if model:
        clauses.append(f'{prefix}"model" = ${idx}')
        params.append(model)
        idx += 1
    if block_name:
        clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})')
        params.append(block_name)
        idx += 1
    if tracking_type:
        clauses.append(f'{prefix}"trackingType" = ${idx}')
        params.append(tracking_type)
        idx += 1
        where["provider"] = provider.lower()

    return (" AND ".join(clauses) if clauses else "TRUE", params)
    if user_id:
        where["userId"] = user_id

    if model:
        where["model"] = model

    if block_name:
        # Case-insensitive match — mirrors the original LOWER() SQL filter.
        where["blockName"] = {"equals": block_name, "mode": "insensitive"}

    if tracking_type:
        where["trackingType"] = tracking_type

    return where


@cached(ttl_seconds=30)
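For a concrete sense of what _build_prisma_where produces, here is a sketch of the filter dict for a typical fully-filtered call. The dict literal below is hand-written from the function body above (Prisma's generated PlatformCostLogWhereInput is a typed dict), so treat it as illustrative rather than captured output.

from datetime import datetime, timezone

start = datetime(2026, 1, 1, tzinfo=timezone.utc)
end = datetime(2026, 2, 1, tzinfo=timezone.utc)

# _build_prisma_where(start, end, "OpenAI", "u1", "gpt-4", "LLMBlock", "tokens")
# is expected to evaluate to roughly:
expected_where = {
    "createdAt": {"gte": start, "lte": end},
    "provider": "openai",                                         # lowercased
    "userId": "u1",
    "model": "gpt-4",
    "blockName": {"equals": "LLMBlock", "mode": "insensitive"},   # mirrors LOWER()
    "trackingType": "tokens",
}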
@@ -241,110 +249,107 @@ async def get_platform_cost_dashboard(
    """
    if start is None:
        start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
    where_p, params_p = _build_where(
        start, end, provider, user_id, "p", model, block_name, tracking_type

    where = _build_prisma_where(
        start, end, provider, user_id, model, block_name, tracking_type
    )

    by_provider_rows, by_user_rows, total_user_rows, total_agg_rows = (
    sum_fields = {
        "costMicrodollars": True,
        "inputTokens": True,
        "outputTokens": True,
        "cacheReadTokens": True,
        "cacheCreationTokens": True,
        "duration": True,
        "trackingAmount": True,
    }

    # Run all four aggregation queries in parallel.
    by_provider_groups, by_user_groups, total_user_groups, total_agg_groups = (
        await asyncio.gather(
            query_raw_with_schema(
                f"""
                SELECT
                    p."provider",
                    p."trackingType" AS tracking_type,
                    p."model",
                    COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                    COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
                    COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
                    COALESCE(SUM(p."cacheReadTokens"), 0)::bigint AS total_cache_read_tokens,
                    COALESCE(SUM(p."cacheCreationTokens"), 0)::bigint AS total_cache_creation_tokens,
                    COALESCE(SUM(p."duration"), 0)::float AS total_duration,
                    COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
                    COUNT(*)::bigint AS request_count
                FROM {{schema_prefix}}"PlatformCostLog" p
                WHERE {where_p}
                GROUP BY p."provider", p."trackingType", p."model"
                ORDER BY total_cost DESC
                LIMIT {MAX_PROVIDER_ROWS}
                """,
                *params_p,
            # (provider, trackingType, model) aggregation — no ORDER BY in ORM;
            # sort by total cost descending in Python after fetch.
            PrismaLog.prisma().group_by(
                by=["provider", "trackingType", "model"],
                where=where,
                sum=sum_fields,
                count=True,
            ),
            query_raw_with_schema(
                f"""
                SELECT
                    p."userId" AS user_id,
                    u."email",
                    COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                    COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
                    COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
                    COUNT(*)::bigint AS request_count
                FROM {{schema_prefix}}"PlatformCostLog" p
                LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
                WHERE {where_p}
                GROUP BY p."userId", u."email"
                ORDER BY total_cost DESC
                LIMIT {MAX_USER_ROWS}
                """,
                *params_p,
            # userId aggregation — emails fetched separately below.
            PrismaLog.prisma().group_by(
                by=["userId"],
                where=where,
                sum=sum_fields,
                count=True,
            ),
            query_raw_with_schema(
                f"""
                SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
                FROM {{schema_prefix}}"PlatformCostLog" p
                WHERE {where_p}
                """,
                *params_p,
            # Distinct user count: group by userId, count groups.
            PrismaLog.prisma().group_by(
                by=["userId"],
                where=where,
                count=True,
            ),
            # Separate aggregate query so dashboard totals are never derived
            # from the capped by_provider_rows list. With model-level grouping,
            # MAX_PROVIDER_ROWS is hit more easily; summing the capped rows
            # would silently undercount once >500 (provider, type, model) exist.
            query_raw_with_schema(
                f"""
                SELECT
                    COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
                    COUNT(*)::bigint AS request_count
                FROM {{schema_prefix}}"PlatformCostLog" p
                WHERE {where_p}
                """,
                *params_p,
            # Total aggregate: group by provider (no limit) to sum across all
            # matching rows. Summed in Python to get grand totals.
            PrismaLog.prisma().group_by(
                by=["provider"],
                where=where,
                sum={"costMicrodollars": True},
                count=True,
            ),
        )
    )

    # Use the exact COUNT(DISTINCT userId) so total_users is not capped at
    # MAX_USER_ROWS (which would silently report 100 for >100 active users).
    total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
    total_cost = int(total_agg_rows[0]["total_cost"]) if total_agg_rows else 0
    total_requests = int(total_agg_rows[0]["request_count"]) if total_agg_rows else 0
    # Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
    by_provider_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True)
    by_provider_groups = by_provider_groups[:MAX_PROVIDER_ROWS]

    # Sort by_user by total cost descending and cap at MAX_USER_ROWS.
    by_user_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True)
    by_user_groups = by_user_groups[:MAX_USER_ROWS]

    # Batch-fetch emails for the users in by_user.
    user_ids = [r["userId"] for r in by_user_groups if r.get("userId") is not None]
    email_by_user_id: dict[str, str | None] = {}
    if user_ids:
        users = await PrismaUser.prisma().find_many(
            where={"id": {"in": user_ids}},
        )
        email_by_user_id = {u.id: u.email for u in users}

    # Total distinct users — exclude the NULL-userId group (deleted users).
    total_users = len([g for g in total_user_groups if g.get("userId") is not None])

    # Grand totals — sum across all provider groups (no LIMIT applied above).
    total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups)
    total_requests = sum(_ca(r) for r in total_agg_groups)

    return PlatformCostDashboard(
        by_provider=[
            ProviderCostSummary(
                provider=r["provider"],
                tracking_type=r.get("tracking_type"),
                tracking_type=r.get("trackingType"),
                model=r.get("model"),
                total_cost_microdollars=r["total_cost"],
                total_input_tokens=r["total_input_tokens"],
                total_output_tokens=r["total_output_tokens"],
                total_cache_read_tokens=r.get("total_cache_read_tokens", 0),
                total_cache_creation_tokens=r.get("total_cache_creation_tokens", 0),
                total_duration_seconds=r.get("total_duration", 0.0),
                total_tracking_amount=r.get("total_tracking_amount", 0.0),
                request_count=r["request_count"],
                total_cost_microdollars=_si(r, "costMicrodollars"),
                total_input_tokens=_si(r, "inputTokens"),
                total_output_tokens=_si(r, "outputTokens"),
                total_cache_read_tokens=_si(r, "cacheReadTokens"),
                total_cache_creation_tokens=_si(r, "cacheCreationTokens"),
                total_duration_seconds=_sf(r, "duration"),
                total_tracking_amount=_sf(r, "trackingAmount"),
                request_count=_ca(r),
            )
            for r in by_provider_rows
            for r in by_provider_groups
        ],
        by_user=[
            UserCostSummary(
                user_id=r.get("user_id"),
                email=_mask_email(r.get("email")),
                total_cost_microdollars=r["total_cost"],
                total_input_tokens=r["total_input_tokens"],
                total_output_tokens=r["total_output_tokens"],
                request_count=r["request_count"],
                user_id=r.get("userId"),
                email=_mask_email(email_by_user_id.get(r.get("userId") or "")),
                total_cost_microdollars=_si(r, "costMicrodollars"),
                total_input_tokens=_si(r, "inputTokens"),
                total_output_tokens=_si(r, "outputTokens"),
                request_count=_ca(r),
            )
            for r in by_user_rows
            for r in by_user_groups
        ],
        total_cost_microdollars=total_cost,
        total_requests=total_requests,
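Two details of this rewrite are easy to miss. First, Prisma's Python client serialises BigInt/Int aggregate sums as strings, so the _si/_sf helpers coerce through int()/float() instead of trusting the raw type. Second, grand totals are summed over a separate, un-limited group_by(["provider"]) result rather than the capped per-model list, so the totals cannot silently undercount once more than MAX_PROVIDER_ROWS (provider, trackingType, model) combinations exist. A small sketch of both, using the row shape the tests later build with _make_group_by_row:

# Row shape returned by PrismaLog.prisma().group_by(...); note the string "5000".
provider_groups = [
    {"provider": "openai", "_sum": {"costMicrodollars": "5000"}, "_count": {"_all": 3}},
    {"provider": "anthropic", "_sum": {"costMicrodollars": 1000}, "_count": {"_all": 1}},
]

def _si(row: dict, field: str) -> int:
    return int((row.get("_sum") or {}).get(field) or 0)

def _ca(row: dict) -> int:
    c = row.get("_count") or {}
    return int(c.get("_all") or 0) if isinstance(c, dict) else int(c or 0)

total_cost = sum(_si(r, "costMicrodollars") for r in provider_groups)
total_requests = sum(_ca(r) for r in provider_groups)
assert (total_cost, total_requests) == (6000, 4)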
@@ -365,73 +370,41 @@ async def get_platform_cost_logs(
) -> tuple[list[CostLogRow], int]:
    if start is None:
        start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
    where_sql, params = _build_where(
        start, end, provider, user_id, "p", model, block_name, tracking_type
    )

    where = _build_prisma_where(
        start, end, provider, user_id, model, block_name, tracking_type
    )
    offset = (page - 1) * page_size
    limit_idx = len(params) + 1
    offset_idx = len(params) + 2

    count_rows, rows = await asyncio.gather(
        query_raw_with_schema(
            f"""
            SELECT COUNT(*)::bigint AS cnt
            FROM {{schema_prefix}}"PlatformCostLog" p
            WHERE {where_sql}
            """,
            *params,
        ),
        query_raw_with_schema(
            f"""
            SELECT
                p."id",
                p."createdAt" AS created_at,
                p."userId" AS user_id,
                u."email",
                p."graphExecId" AS graph_exec_id,
                p."nodeExecId" AS node_exec_id,
                p."blockName" AS block_name,
                p."provider",
                p."trackingType" AS tracking_type,
                p."costMicrodollars" AS cost_microdollars,
                p."inputTokens" AS input_tokens,
                p."outputTokens" AS output_tokens,
                p."cacheReadTokens" AS cache_read_tokens,
                p."cacheCreationTokens" AS cache_creation_tokens,
                p."duration",
                p."model"
            FROM {{schema_prefix}}"PlatformCostLog" p
            LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
            WHERE {where_sql}
            ORDER BY p."createdAt" DESC, p."id" DESC
            LIMIT ${limit_idx} OFFSET ${offset_idx}
            """,
            *params,
            page_size,
            offset,
    total, rows = await asyncio.gather(
        PrismaLog.prisma().count(where=where),
        PrismaLog.prisma().find_many(
            where=where,
            include={"User": True},
            order=[{"createdAt": "desc"}, {"id": "desc"}],
            take=page_size,
            skip=offset,
        ),
    )
    total = count_rows[0]["cnt"] if count_rows else 0

    logs = [
        CostLogRow(
            id=r["id"],
            created_at=r["created_at"],
            user_id=r.get("user_id"),
            email=_mask_email(r.get("email")),
            graph_exec_id=r.get("graph_exec_id"),
            node_exec_id=r.get("node_exec_id"),
            block_name=r["block_name"],
            provider=r["provider"],
            tracking_type=r.get("tracking_type"),
            cost_microdollars=r.get("cost_microdollars"),
            input_tokens=r.get("input_tokens"),
            output_tokens=r.get("output_tokens"),
            cache_read_tokens=r.get("cache_read_tokens"),
            cache_creation_tokens=r.get("cache_creation_tokens"),
            duration=r.get("duration"),
            model=r.get("model"),
            id=r.id,
            created_at=r.createdAt,
            user_id=r.userId,
            email=_mask_email(r.User.email if r.User else None),
            graph_exec_id=r.graphExecId,
            node_exec_id=r.nodeExecId,
            block_name=r.blockName or "",
            provider=r.provider,
            tracking_type=r.trackingType,
            cost_microdollars=r.costMicrodollars,
            input_tokens=r.inputTokens,
            output_tokens=r.outputTokens,
            cache_read_tokens=getattr(r, "cacheReadTokens", None),
            cache_creation_tokens=getattr(r, "cacheCreationTokens", None),
            duration=r.duration,
            model=r.model,
        )
        for r in rows
    ]
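Pagination keeps the same semantics as the old LIMIT/OFFSET SQL, now expressed as Prisma's take/skip arguments. The arithmetic, with the numbers asserted by test_pagination_offset further down, fits in a couple of lines; the helper name below is illustrative only.

def page_window(page: int, page_size: int) -> dict:
    """take/skip arguments for PrismaLog.prisma().find_many(), 1-indexed pages."""
    return {"take": page_size, "skip": (page - 1) * page_size}

assert page_window(1, 10) == {"take": 10, "skip": 0}
assert page_window(3, 25) == {"take": 25, "skip": 50}   # matches the test below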
@@ -457,38 +430,16 @@ async def get_platform_cost_logs_for_export(
    """
    if start is None:
        start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
    where_sql, params = _build_where(
        start, end, provider, user_id, "p", model, block_name, tracking_type
    )
    limit_idx = len(params) + 1

    rows = await query_raw_with_schema(
        f"""
        SELECT
            p."id",
            p."createdAt" AS created_at,
            p."userId" AS user_id,
            u."email",
            p."graphExecId" AS graph_exec_id,
            p."nodeExecId" AS node_exec_id,
            p."blockName" AS block_name,
            p."provider",
            p."trackingType" AS tracking_type,
            p."costMicrodollars" AS cost_microdollars,
            p."inputTokens" AS input_tokens,
            p."outputTokens" AS output_tokens,
            p."cacheReadTokens" AS cache_read_tokens,
            p."cacheCreationTokens" AS cache_creation_tokens,
            p."duration",
            p."model"
        FROM {{schema_prefix}}"PlatformCostLog" p
        LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
        WHERE {where_sql}
        ORDER BY p."createdAt" DESC, p."id" DESC
        LIMIT ${limit_idx}
        """,
        *params,
        EXPORT_MAX_ROWS + 1,
    where = _build_prisma_where(
        start, end, provider, user_id, model, block_name, tracking_type
    )

    rows = await PrismaLog.prisma().find_many(
        where=where,
        include={"User": True},
        order=[{"createdAt": "desc"}, {"id": "desc"}],
        take=EXPORT_MAX_ROWS + 1,
    )

    truncated = len(rows) > EXPORT_MAX_ROWS
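The export path uses the usual "fetch one more row than you will return" trick: asking for EXPORT_MAX_ROWS + 1 rows lets the function report truncation without a second COUNT query. A generic sketch of the pattern, with an illustrative helper name:

from typing import Sequence, TypeVar

T = TypeVar("T")

def clamp_for_export(rows: Sequence[T], max_rows: int) -> tuple[list[T], bool]:
    """Given rows fetched with take=max_rows + 1, return (rows, truncated)."""
    truncated = len(rows) > max_rows
    return list(rows[:max_rows]), truncated

assert clamp_for_export([1, 2, 3], max_rows=2) == ([1, 2], True)
assert clamp_for_export([1, 2], max_rows=2) == ([1, 2], False)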
@@ -496,22 +447,80 @@ async def get_platform_cost_logs_for_export(

    return [
        CostLogRow(
            id=r["id"],
            created_at=r["created_at"],
            user_id=r.get("user_id"),
            email=_mask_email(r.get("email")),
            graph_exec_id=r.get("graph_exec_id"),
            node_exec_id=r.get("node_exec_id"),
            block_name=r["block_name"],
            provider=r["provider"],
            tracking_type=r.get("tracking_type"),
            cost_microdollars=r.get("cost_microdollars"),
            input_tokens=r.get("input_tokens"),
            output_tokens=r.get("output_tokens"),
            cache_read_tokens=r.get("cache_read_tokens"),
            cache_creation_tokens=r.get("cache_creation_tokens"),
            duration=r.get("duration"),
            model=r.get("model"),
            id=r.id,
            created_at=r.createdAt,
            user_id=r.userId,
            email=_mask_email(r.User.email if r.User else None),
            graph_exec_id=r.graphExecId,
            node_exec_id=r.nodeExecId,
            block_name=r.blockName or "",
            provider=r.provider,
            tracking_type=r.trackingType,
            cost_microdollars=r.costMicrodollars,
            input_tokens=r.inputTokens,
            output_tokens=r.outputTokens,
            cache_read_tokens=getattr(r, "cacheReadTokens", None),
            cache_creation_tokens=getattr(r, "cacheCreationTokens", None),
            duration=r.duration,
            model=r.model,
        )
        for r in rows
    ], truncated


# ---------------------------------------------------------------------------
# Helpers kept for backward-compatibility with existing tests.
# New code should not use these — use _build_prisma_where instead.
# ---------------------------------------------------------------------------


def _build_where(
    start: datetime | None,
    end: datetime | None,
    provider: str | None,
    user_id: str | None,
    table_alias: str = "",
    model: str | None = None,
    block_name: str | None = None,
    tracking_type: str | None = None,
) -> tuple[str, list[Any]]:
    """Legacy SQL WHERE builder — retained so existing unit tests still pass.

    Only used by tests that verify the SQL-string generation logic. All
    production code uses _build_prisma_where instead.
    """
    prefix = f"{table_alias}." if table_alias else ""
    clauses: list[str] = []
    params: list[Any] = []
    idx = 1

    if start:
        clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
        params.append(start)
        idx += 1
    if end:
        clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
        params.append(end)
        idx += 1
    if provider:
        clauses.append(f'{prefix}"provider" = ${idx}')
        params.append(provider.lower())
        idx += 1
    if user_id:
        clauses.append(f'{prefix}"userId" = ${idx}')
        params.append(user_id)
        idx += 1
    if model:
        clauses.append(f'{prefix}"model" = ${idx}')
        params.append(model)
        idx += 1
    if block_name:
        clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})')
        params.append(block_name)
        idx += 1
    if tracking_type:
        clauses.append(f'{prefix}"trackingType" = ${idx}')
        params.append(tracking_type)
        idx += 1

    return (" AND ".join(clauses) if clauses else "TRUE", params)
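Since the legacy builder only survives for the existing SQL-string tests, it helps to see what it actually emits. A hedged example of its output for a filtered call, written from the function body above rather than captured from a run:

from datetime import datetime, timezone

start = datetime(2026, 1, 1, tzinfo=timezone.utc)

# _build_where(start, None, "OpenAI", "u1", "p", model="gpt-4") should yield a
# clause string with 1-based positional parameters and the matching params list:
expected_sql = (
    'p."createdAt" >= $1::timestamptz AND p."provider" = $2 '
    'AND p."userId" = $3 AND p."model" = $4'
)
expected_params = [start, "openai", "u1", "gpt-4"]   # provider lowercased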
@@ -1,7 +1,7 @@
"""Unit tests for helpers and async functions in platform_cost module."""

from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from prisma import Json
@@ -224,6 +224,41 @@ class TestLogPlatformCostSafe:
        mock_create.assert_awaited_once()


def _make_group_by_row(
    provider: str = "openai",
    tracking_type: str | None = "tokens",
    model: str | None = None,
    cost: int = 5000,
    input_tokens: int = 1000,
    output_tokens: int = 500,
    cache_read_tokens: int = 0,
    cache_creation_tokens: int = 0,
    duration: float = 10.5,
    tracking_amount: float = 0.0,
    count: int = 3,
    user_id: str | None = None,
) -> dict:
    row: dict = {
        "_sum": {
            "costMicrodollars": cost,
            "inputTokens": input_tokens,
            "outputTokens": output_tokens,
            "cacheReadTokens": cache_read_tokens,
            "cacheCreationTokens": cache_creation_tokens,
            "duration": duration,
            "trackingAmount": tracking_amount,
        },
        "_count": {"_all": count},
    }
    if user_id is not None:
        row["userId"] = user_id
    else:
        row["provider"] = provider
        row["trackingType"] = tracking_type
        row["model"] = model
    return row


class TestGetPlatformCostDashboard:
    def setup_method(self):
        # @cached stores results in-process; clear between tests to avoid bleed.
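The helper above returns the minimal dict shape the production _si/_sf/_ca helpers read: a _sum mapping, a _count mapping, and either the grouping keys for a provider-keyed row or just userId for a user-keyed row. For illustration, the defaults expand to roughly the following (hand-expanded here, not generated output):

# _make_group_by_row() with defaults, i.e. a provider-keyed group_by row:
default_provider_row = {
    "_sum": {
        "costMicrodollars": 5000,
        "inputTokens": 1000,
        "outputTokens": 500,
        "cacheReadTokens": 0,
        "cacheCreationTokens": 0,
        "duration": 10.5,
        "trackingAmount": 0.0,
    },
    "_count": {"_all": 3},
    "provider": "openai",
    "trackingType": "tokens",
    "model": None,
}

# _make_group_by_row(user_id="u1") instead carries "userId" in place of the
# provider/trackingType/model keys.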
@@ -231,35 +266,44 @@ class TestGetPlatformCostDashboard:

    @pytest.mark.asyncio
    async def test_returns_dashboard_with_data(self):
        provider_rows = [
            {
                "provider": "openai",
                "tracking_type": "tokens",
                "total_cost": 5000,
                "total_input_tokens": 1000,
                "total_output_tokens": 500,
                "total_duration": 10.5,
                "request_count": 3,
            }
        ]
        user_rows = [
            {
                "user_id": "u1",
                "email": "a@b.com",
                "total_cost": 5000,
                "total_input_tokens": 1000,
                "total_output_tokens": 500,
                "request_count": 3,
            }
        ]
        # Dashboard runs 4 queries: by_provider, by_user, COUNT(DISTINCT userId),
        # and a separate total aggregate (total_cost + request_count with no LIMIT).
        agg_rows = [{"total_cost": 5000, "request_count": 3}]
        mock_query = AsyncMock(
            side_effect=[provider_rows, user_rows, [{"cnt": 1}], agg_rows]
        provider_row = _make_group_by_row(
            provider="openai",
            tracking_type="tokens",
            cost=5000,
            input_tokens=1000,
            output_tokens=500,
            duration=10.5,
            count=3,
        )
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        user_row = _make_group_by_row(user_id="u1", cost=5000, count=3)

        mock_user = MagicMock()
        mock_user.id = "u1"
        mock_user.email = "a@b.com"

        mock_actions = MagicMock()
        mock_actions.group_by = AsyncMock(
            side_effect=[
                [provider_row],  # by_provider
                [user_row],  # by_user
                [{"userId": "u1"}],  # distinct users
                [provider_row],  # total agg
            ]
        )
        mock_actions.find_many = AsyncMock(return_value=[mock_user])

        with (
            patch(
                "backend.data.platform_cost.PrismaLog.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.PrismaUser.prisma",
                return_value=mock_actions,
            ),
        ):
            dashboard = await get_platform_cost_dashboard()

        assert dashboard.total_cost_microdollars == 5000
        assert dashboard.total_requests == 3
        assert dashboard.total_users == 1
@@ -271,10 +315,67 @@ class TestGetPlatformCostDashboard:
        assert dashboard.by_user[0].email == "a***@b.com"

    @pytest.mark.asyncio
    async def test_returns_empty_dashboard(self):
        mock_query = AsyncMock(side_effect=[[], [], [], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
    async def test_cache_tokens_aggregated_not_hardcoded(self):
        """cache_read_tokens and cache_creation_tokens must be read from the
        DB aggregation, not hardcoded to 0 (regression guard for Sentry report)."""
        provider_row = _make_group_by_row(
            provider="anthropic",
            tracking_type="tokens",
            cost=1000,
            input_tokens=800,
            output_tokens=200,
            cache_read_tokens=400,
            cache_creation_tokens=100,
            count=1,
        )
        user_row = _make_group_by_row(user_id="u2", cost=1000, count=1)

        mock_actions = MagicMock()
        mock_actions.group_by = AsyncMock(
            side_effect=[
                [provider_row],  # by_provider
                [user_row],  # by_user
                [{"userId": "u2"}],  # distinct users
                [provider_row],  # total agg
            ]
        )
        mock_actions.find_many = AsyncMock(return_value=[])

        with (
            patch(
                "backend.data.platform_cost.PrismaLog.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.PrismaUser.prisma",
                return_value=mock_actions,
            ),
        ):
            dashboard = await get_platform_cost_dashboard()

        assert len(dashboard.by_provider) == 1
        row = dashboard.by_provider[0]
        assert row.total_cache_read_tokens == 400
        assert row.total_cache_creation_tokens == 100

    @pytest.mark.asyncio
    async def test_returns_empty_dashboard(self):
        mock_actions = MagicMock()
        mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
        mock_actions.find_many = AsyncMock(return_value=[])

        with (
            patch(
                "backend.data.platform_cost.PrismaLog.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.PrismaUser.prisma",
                return_value=mock_actions,
            ),
        ):
            dashboard = await get_platform_cost_dashboard()

        assert dashboard.total_cost_microdollars == 0
        assert dashboard.total_requests == 0
        assert dashboard.total_users == 0
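The dashboard and log rows never expose raw addresses; _mask_email is applied everywhere an email leaves the data layer. Its implementation is not part of this diff, so the sketch below only reproduces the behaviour the tests assert ("a@b.com" becomes "a***@b.com"): keep the first character of the local part, replace the rest with "***", keep the domain. Treat anything beyond that one example as an assumption.

def mask_email_sketch(email: str | None) -> str | None:
    # Illustrative only: the real _mask_email lives in platform_cost.py and may
    # handle edge cases (short local parts, missing "@") differently.
    if not email or "@" not in email:
        return email
    local, domain = email.split("@", 1)
    return f"{local[:1]}***@{domain}"

assert mask_email_sketch("a@b.com") == "a***@b.com"   # matches the test assertion
assert mask_email_sketch(None) is None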
@@ -284,160 +385,228 @@ class TestGetPlatformCostDashboard:
    @pytest.mark.asyncio
    async def test_passes_filters_to_queries(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        mock_query = AsyncMock(side_effect=[[], [], [], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):

        mock_actions = MagicMock()
        mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
        mock_actions.find_many = AsyncMock(return_value=[])

        with (
            patch(
                "backend.data.platform_cost.PrismaLog.prisma",
                return_value=mock_actions,
            ),
            patch(
                "backend.data.platform_cost.PrismaUser.prisma",
                return_value=mock_actions,
            ),
        ):
            await get_platform_cost_dashboard(
                start=start, provider="openai", user_id="u1"
            )
            assert mock_query.await_count == 4
            first_call_sql = mock_query.call_args_list[0][0][0]
            assert "createdAt" in first_call_sql

        # group_by called 4 times (by_provider, by_user, distinct users, totals)
        assert mock_actions.group_by.await_count == 4
        # The where dict passed to the first call should include createdAt
        first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
        assert "createdAt" in first_call_kwargs.get("where", {})


def _make_prisma_log_row(
    i: int = 0,
    user_email: str | None = None,
) -> MagicMock:
    row = MagicMock()
    row.id = f"log-{i}"
    row.createdAt = datetime(2026, 3, 1, tzinfo=timezone.utc)
    row.userId = "u1"
    row.graphExecId = None
    row.nodeExecId = None
    row.blockName = "TestBlock"
    row.provider = "openai"
    row.trackingType = "tokens"
    row.costMicrodollars = 1000
    row.inputTokens = 10
    row.outputTokens = 5
    row.duration = 0.5
    row.model = "gpt-4"
    # cacheReadTokens / cacheCreationTokens may not exist on older Prisma clients
    row.configure_mock(**{"cacheReadTokens": None, "cacheCreationTokens": None})
    if user_email is not None:
        row.User = MagicMock()
        row.User.email = user_email
    else:
        row.User = None
    return row


class TestGetPlatformCostLogs:
    @pytest.mark.asyncio
    async def test_returns_logs_and_total(self):
        count_rows = [{"cnt": 1}]
        log_rows = [
            {
                "id": "log-1",
                "created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
                "user_id": "u1",
                "email": "a@b.com",
                "graph_exec_id": "g1",
                "node_exec_id": "n1",
                "block_name": "TestBlock",
                "provider": "openai",
                "tracking_type": "tokens",
                "cost_microdollars": 5000,
                "input_tokens": 100,
                "output_tokens": 50,
                "duration": 1.5,
                "model": "gpt-4",
            }
        ]
        mock_query = AsyncMock(side_effect=[count_rows, log_rows])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        row = _make_prisma_log_row(0, user_email="a@b.com")
        mock_actions = MagicMock()
        mock_actions.count = AsyncMock(return_value=1)
        mock_actions.find_many = AsyncMock(return_value=[row])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs(page=1, page_size=10)

        assert total == 1
        assert len(logs) == 1
        assert logs[0].id == "log-1"
        assert logs[0].id == "log-0"
        assert logs[0].provider == "openai"
        assert logs[0].model == "gpt-4"

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_data(self):
        mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        mock_actions = MagicMock()
        mock_actions.count = AsyncMock(return_value=0)
        mock_actions.find_many = AsyncMock(return_value=[])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs()

        assert total == 0
        assert logs == []

    @pytest.mark.asyncio
    async def test_pagination_offset(self):
        mock_query = AsyncMock(side_effect=[[{"cnt": 100}], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs(page=3, page_size=25)
            assert total == 100
            second_call_args = mock_query.call_args_list[1][0]
            assert 25 in second_call_args  # page_size
            assert 50 in second_call_args  # offset = (3-1) * 25
        mock_actions = MagicMock()
        mock_actions.count = AsyncMock(return_value=100)
        mock_actions.find_many = AsyncMock(return_value=[])

    @pytest.mark.asyncio
    async def test_empty_count_returns_zero(self):
        mock_query = AsyncMock(side_effect=[[], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
            logs, total = await get_platform_cost_logs()
            assert total == 0
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs(page=3, page_size=25)

        assert total == 100
        find_many_call = mock_actions.find_many.call_args[1]
        assert find_many_call["take"] == 25
        assert find_many_call["skip"] == 50  # (3-1) * 25

    @pytest.mark.asyncio
    async def test_explicit_start_skips_default(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        mock_query = AsyncMock(side_effect=[[{"cnt": 0}], []])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        mock_actions = MagicMock()
        mock_actions.count = AsyncMock(return_value=0)
        mock_actions.find_many = AsyncMock(return_value=[])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs(start=start)

        assert total == 0


def _make_log_row(i: int = 0) -> dict:
    return {
        "id": f"log-{i}",
        "created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
        "user_id": "u1",
        "email": None,
        "graph_exec_id": None,
        "node_exec_id": None,
        "block_name": "TestBlock",
        "provider": "openai",
        "tracking_type": "tokens",
        "cost_microdollars": 1000,
        "input_tokens": 10,
        "output_tokens": 5,
        "duration": 0.5,
        "model": "gpt-4",
        "cache_read_tokens": None,
        "cache_creation_tokens": None,
    }
        where = mock_actions.count.call_args[1]["where"]
        # start provided — should appear in the where filter
        assert "createdAt" in where


class TestGetPlatformCostLogsForExport:
    @pytest.mark.asyncio
    async def test_returns_logs_not_truncated(self):
        rows = [_make_log_row(0)]
        mock_query = AsyncMock(return_value=rows)
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        row = _make_prisma_log_row(0)
        mock_actions = MagicMock()
        mock_actions.find_many = AsyncMock(return_value=[row])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, truncated = await get_platform_cost_logs_for_export()

        assert len(logs) == 1
        assert truncated is False
        assert logs[0].id == "log-0"

    @pytest.mark.asyncio
    async def test_returns_empty_not_truncated(self):
        mock_query = AsyncMock(return_value=[])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        mock_actions = MagicMock()
        mock_actions.find_many = AsyncMock(return_value=[])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, truncated = await get_platform_cost_logs_for_export()

        assert logs == []
        assert truncated is False

    @pytest.mark.asyncio
    async def test_truncates_at_export_max_rows(self):
        rows = [_make_log_row(i) for i in range(3)]
        mock_query = AsyncMock(return_value=rows)
        with patch(
            "backend.data.platform_cost.query_raw_with_schema", new=mock_query
        ), patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2):
        rows = [_make_prisma_log_row(i) for i in range(3)]
        mock_actions = MagicMock()
        mock_actions.find_many = AsyncMock(return_value=rows)

        with (
            patch(
                "backend.data.platform_cost.PrismaLog.prisma",
                return_value=mock_actions,
            ),
            patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2),
        ):
            logs, truncated = await get_platform_cost_logs_for_export()

        assert len(logs) == 2
        assert truncated is True

    @pytest.mark.asyncio
    async def test_passes_model_block_tracking_filters(self):
        mock_query = AsyncMock(return_value=[])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        mock_actions = MagicMock()
        mock_actions.find_many = AsyncMock(return_value=[])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            await get_platform_cost_logs_for_export(
                model="gpt-4", block_name="LLMBlock", tracking_type="tokens"
            )
            call_args = mock_query.call_args[0]
            assert "gpt-4" in call_args
            assert "LLMBlock" in call_args
            assert "tokens" in call_args

        where = mock_actions.find_many.call_args[1]["where"]
        assert where.get("model") == "gpt-4"
        assert where.get("trackingType") == "tokens"
        # blockName uses a dict filter for case-insensitive match
        assert "blockName" in where

    @pytest.mark.asyncio
    async def test_maps_cache_tokens(self):
        row = _make_log_row(0)
        row["cache_read_tokens"] = 50
        row["cache_creation_tokens"] = 25
        mock_query = AsyncMock(return_value=[row])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        row = _make_prisma_log_row(0)
        row.configure_mock(**{"cacheReadTokens": 50, "cacheCreationTokens": 25})
        mock_actions = MagicMock()
        mock_actions.find_many = AsyncMock(return_value=[row])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, _ = await get_platform_cost_logs_for_export()

        assert logs[0].cache_read_tokens == 50
        assert logs[0].cache_creation_tokens == 25

    @pytest.mark.asyncio
    async def test_explicit_start_skips_default(self):
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        mock_query = AsyncMock(return_value=[])
        with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
        mock_actions = MagicMock()
        mock_actions.find_many = AsyncMock(return_value=[])

        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, truncated = await get_platform_cost_logs_for_export(start=start)

        assert logs == []
        assert truncated is False
        where = mock_actions.find_many.call_args[1]["where"]
        assert "createdAt" in where
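One detail worth calling out from the rows above: the production code reads the cache-token columns with getattr(r, "cacheReadTokens", None) rather than plain attribute access, and the test helper's comment explains why, since the attributes may not exist on Prisma client classes generated before the columns were added. A tiny sketch of the difference; the stand-in class below is illustrative only.

class OldPrismaRow:
    """Stand-in for a Prisma model generated before the cache-token columns existed."""
    inputTokens = 10

row = OldPrismaRow()

# Plain attribute access would raise AttributeError on the old client...
# row.cacheReadTokens  # AttributeError
# ...while the defensive read used in platform_cost.py degrades to None.
assert getattr(row, "cacheReadTokens", None) is None
assert getattr(row, "inputTokens", None) == 10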