Merge branch 'dev' of https://github.com/Significant-Gravitas/AutoGPT into feat/copilot-pending-messages

This commit is contained in:
majdyz
2026-04-13 07:12:49 +00:00
5 changed files with 697 additions and 350 deletions

View File

@@ -0,0 +1,100 @@
-- =============================================================
-- View: analytics.platform_cost_log
-- Looker source alias: ds115 | Charts: 0
-- =============================================================
-- DESCRIPTION
-- One row per platform cost log entry (last 90 days).
-- Tracks real API spend at the call level: provider, model,
-- token counts (including Anthropic cache tokens), cost in
-- microdollars, and the block/execution that incurred the cost.
-- Joins the User table to provide email for per-user breakdowns.
--
-- SOURCE TABLES
-- platform.PlatformCostLog — Per-call cost records
-- platform.User — User email
--
-- OUTPUT COLUMNS
-- id TEXT Log entry UUID
-- createdAt TIMESTAMPTZ When the cost was recorded
-- userId TEXT User who incurred the cost (nullable)
-- email TEXT User email (nullable)
-- graphExecId TEXT Graph execution UUID (nullable)
-- nodeExecId TEXT Node execution UUID (nullable)
-- blockName TEXT Block that made the API call (nullable)
-- provider TEXT API provider, lowercase (e.g. 'openai', 'anthropic')
-- model TEXT Model name (nullable)
-- trackingType TEXT Cost unit: 'tokens' | 'cost_usd' | 'characters' | etc.
-- costMicrodollars BIGINT Cost in microdollars (divide by 1,000,000 for USD)
-- costUsd FLOAT Cost in USD (costMicrodollars / 1,000,000)
-- inputTokens INT Prompt/input tokens (nullable)
-- outputTokens INT Completion/output tokens (nullable)
-- cacheReadTokens INT Anthropic cache-read tokens billed at 10% (nullable)
-- cacheCreationTokens INT Anthropic cache-write tokens billed at 125% (nullable)
-- totalTokens INT inputTokens + outputTokens (nullable if either is null)
-- duration FLOAT API call duration in seconds (nullable)
--
-- WINDOW
-- Rolling 90 days (createdAt > CURRENT_DATE - 90 days)
--
-- EXAMPLE QUERIES
-- -- Total spend by provider (last 90 days)
-- SELECT provider, SUM("costUsd") AS total_usd, COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- GROUP BY 1 ORDER BY total_usd DESC;
--
-- -- Spend by model
-- SELECT provider, model, SUM("costUsd") AS total_usd,
-- SUM("inputTokens") AS input_tokens,
-- SUM("outputTokens") AS output_tokens
-- FROM analytics.platform_cost_log
-- WHERE model IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_usd DESC;
--
-- -- Top 20 users by spend
-- SELECT "userId", email, SUM("costUsd") AS total_usd, COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- WHERE "userId" IS NOT NULL
-- GROUP BY 1, 2 ORDER BY total_usd DESC LIMIT 20;
--
-- -- Daily spend trend
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
-- SUM("costUsd") AS daily_usd,
-- COUNT(*) AS calls
-- FROM analytics.platform_cost_log
-- GROUP BY 1 ORDER BY 1;
--
-- -- Cache hit rate for Anthropic (cache reads vs total reads)
-- SELECT DATE_TRUNC('day', "createdAt") AS day,
-- SUM("cacheReadTokens")::float /
-- NULLIF(SUM("inputTokens" + COALESCE("cacheReadTokens", 0)), 0) AS cache_hit_rate
-- FROM analytics.platform_cost_log
-- WHERE provider = 'anthropic'
-- GROUP BY 1 ORDER BY 1;
-- =============================================================
SELECT
p."id" AS id,
p."createdAt" AS createdAt,
p."userId" AS userId,
u."email" AS email,
p."graphExecId" AS graphExecId,
p."nodeExecId" AS nodeExecId,
p."blockName" AS blockName,
p."provider" AS provider,
p."model" AS model,
p."trackingType" AS trackingType,
p."costMicrodollars" AS costMicrodollars,
-- Precompute USD once here so dashboard queries never repeat the division.
p."costMicrodollars"::float / 1000000.0 AS costUsd,
p."inputTokens" AS inputTokens,
p."outputTokens" AS outputTokens,
p."cacheReadTokens" AS cacheReadTokens,
p."cacheCreationTokens" AS cacheCreationTokens,
-- totalTokens is NULL when either side is NULL (per the header contract).
CASE
WHEN p."inputTokens" IS NOT NULL AND p."outputTokens" IS NOT NULL
THEN p."inputTokens" + p."outputTokens"
ELSE NULL
END AS totalTokens,
p."duration" AS duration
FROM platform."PlatformCostLog" p
-- LEFT JOIN keeps cost rows whose user no longer exists (userId/email NULL).
LEFT JOIN platform."User" u ON u."id" = p."userId"
-- Rolling 90-day window; matches the WINDOW section of the header comment.
WHERE p."createdAt" > CURRENT_DATE - INTERVAL '90 days'

View File

@@ -887,6 +887,21 @@ async def llm_call(
provider = llm_model.metadata.provider
context_window = llm_model.context_window
# Transparent OpenRouter routing for Anthropic models: when an OpenRouter API key
# is configured, route direct-Anthropic models through OpenRouter instead. This
# gives us the x-total-cost header for free, so provider_cost is always populated
# without manual token-rate arithmetic.
or_key = settings.secrets.open_router_api_key
or_model_id: str | None = None
if provider == "anthropic" and or_key:
provider = "open_router"
credentials = APIKeyCredentials(
provider=ProviderName.OPEN_ROUTER,
title="OpenRouter (auto)",
api_key=SecretStr(or_key),
)
or_model_id = f"anthropic/{llm_model.value}"
if compress_prompt_to_fit:
result = await compress_context(
messages=prompt,
@@ -1134,7 +1149,7 @@ async def llm_call(
"HTTP-Referer": "https://agpt.co",
"X-Title": "AutoGPT",
},
model=llm_model.value,
model=or_model_id or llm_model.value,
messages=prompt, # type: ignore
max_tokens=max_tokens,
tools=tools_param, # type: ignore

View File

@@ -77,7 +77,11 @@ class TestLLMStatsTracking:
mock_response.usage = mock_usage
mock_response.stop_reason = "end_turn"
with patch("anthropic.AsyncAnthropic") as mock_anthropic:
with (
patch("anthropic.AsyncAnthropic") as mock_anthropic,
patch("backend.blocks.llm.settings") as mock_settings,
):
mock_settings.secrets.open_router_api_key = ""
mock_client = AsyncMock()
mock_anthropic.return_value = mock_client
mock_client.messages.create = AsyncMock(return_value=mock_response)
@@ -96,6 +100,56 @@ class TestLLMStatsTracking:
assert response.cache_creation_tokens == 50
assert response.response == "Test anthropic response"
@pytest.mark.asyncio
async def test_anthropic_routes_through_openrouter_when_key_present(self):
"""When open_router_api_key is set, Anthropic models route via OpenRouter."""
from pydantic import SecretStr
import backend.blocks.llm as llm
from backend.data.model import APIKeyCredentials
anthropic_creds = APIKeyCredentials(
id="test-anthropic-id",
provider="anthropic",
api_key=SecretStr("mock-anthropic-key"),
title="Mock Anthropic key",
)
mock_choice = MagicMock()
mock_choice.message.content = "routed response"
mock_choice.message.tool_calls = None
mock_usage = MagicMock()
mock_usage.prompt_tokens = 10
mock_usage.completion_tokens = 5
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_response.usage = mock_usage
mock_create = AsyncMock(return_value=mock_response)
with (
patch("openai.AsyncOpenAI") as mock_openai,
patch("backend.blocks.llm.settings") as mock_settings,
):
mock_settings.secrets.open_router_api_key = "sk-or-test-key"
mock_client = MagicMock()
mock_openai.return_value = mock_client
mock_client.chat.completions.create = mock_create
await llm.llm_call(
credentials=anthropic_creds,
llm_model=llm.LlmModel.CLAUDE_3_HAIKU,
prompt=[{"role": "user", "content": "Hello"}],
max_tokens=100,
)
# Verify OpenAI client was used (not Anthropic SDK) and model was prefixed
mock_openai.assert_called_once()
call_kwargs = mock_create.call_args.kwargs
assert call_kwargs["model"] == "anthropic/claude-3-haiku-20240307"
@pytest.mark.asyncio
async def test_ai_structured_response_block_tracks_stats(self):
"""Test that AIStructuredResponseGeneratorBlock correctly tracks stats."""

View File

@@ -4,10 +4,10 @@ from datetime import datetime, timedelta, timezone
from typing import Any
from prisma.models import PlatformCostLog as PrismaLog
from prisma.types import PlatformCostLogCreateInput
from prisma.models import User as PrismaUser
from prisma.types import PlatformCostLogCreateInput, PlatformCostLogWhereInput
from pydantic import BaseModel
from backend.data.db import query_raw_with_schema
from backend.util.cache import cached
from backend.util.json import SafeJson
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
MICRODOLLARS_PER_USD = 1_000_000
# Dashboard query limits — keep in sync with the SQL queries below
# Dashboard query limits
MAX_PROVIDER_ROWS = 500
MAX_USER_ROWS = 100
@@ -169,53 +169,61 @@ class PlatformCostDashboard(BaseModel):
total_users: int
def _build_where(
def _si(row: dict, field: str) -> int:
"""Extract an integer from a Prisma group_by _sum dict.
Prisma Python serialises BigInt/Int aggregate sums as strings; coerce to int.
"""
return int((row.get("_sum") or {}).get(field) or 0)
def _sf(row: dict, field: str) -> float:
"""Extract a float from a Prisma group_by _sum dict."""
return float((row.get("_sum") or {}).get(field) or 0.0)
def _ca(row: dict) -> int:
"""Extract _count._all from a Prisma group_by row."""
c = row.get("_count") or {}
return int(c.get("_all") or 0) if isinstance(c, dict) else int(c or 0)
def _build_prisma_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
table_alias: str = "",
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
) -> tuple[str, list[Any]]:
prefix = f"{table_alias}." if table_alias else ""
clauses: list[str] = []
params: list[Any] = []
idx = 1
) -> PlatformCostLogWhereInput:
"""Build a Prisma WhereInput for PlatformCostLog filters."""
where: PlatformCostLogWhereInput = {}
if start and end:
where["createdAt"] = {"gte": start, "lte": end}
elif start:
where["createdAt"] = {"gte": start}
elif end:
where["createdAt"] = {"lte": end}
if start:
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
params.append(start)
idx += 1
if end:
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
params.append(end)
idx += 1
if provider:
# Provider names are normalized to lowercase at write time so a plain
# equality check is sufficient and the (provider, createdAt) index is used.
clauses.append(f'{prefix}"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id:
clauses.append(f'{prefix}"userId" = ${idx}')
params.append(user_id)
idx += 1
if model:
clauses.append(f'{prefix}"model" = ${idx}')
params.append(model)
idx += 1
if block_name:
clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})')
params.append(block_name)
idx += 1
if tracking_type:
clauses.append(f'{prefix}"trackingType" = ${idx}')
params.append(tracking_type)
idx += 1
where["provider"] = provider.lower()
return (" AND ".join(clauses) if clauses else "TRUE", params)
if user_id:
where["userId"] = user_id
if model:
where["model"] = model
if block_name:
# Case-insensitive match — mirrors the original LOWER() SQL filter.
where["blockName"] = {"equals": block_name, "mode": "insensitive"}
if tracking_type:
where["trackingType"] = tracking_type
return where
@cached(ttl_seconds=30)
@@ -241,110 +249,107 @@ async def get_platform_cost_dashboard(
"""
if start is None:
start = datetime.now(timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_p, params_p = _build_where(
start, end, provider, user_id, "p", model, block_name, tracking_type
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
)
by_provider_rows, by_user_rows, total_user_rows, total_agg_rows = (
sum_fields = {
"costMicrodollars": True,
"inputTokens": True,
"outputTokens": True,
"cacheReadTokens": True,
"cacheCreationTokens": True,
"duration": True,
"trackingAmount": True,
}
# Run all four aggregation queries in parallel.
by_provider_groups, by_user_groups, total_user_groups, total_agg_groups = (
await asyncio.gather(
query_raw_with_schema(
f"""
SELECT
p."provider",
p."trackingType" AS tracking_type,
p."model",
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
COALESCE(SUM(p."cacheReadTokens"), 0)::bigint AS total_cache_read_tokens,
COALESCE(SUM(p."cacheCreationTokens"), 0)::bigint AS total_cache_creation_tokens,
COALESCE(SUM(p."duration"), 0)::float AS total_duration,
COALESCE(SUM(p."trackingAmount"), 0)::float AS total_tracking_amount,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
GROUP BY p."provider", p."trackingType", p."model"
ORDER BY total_cost DESC
LIMIT {MAX_PROVIDER_ROWS}
""",
*params_p,
# (provider, trackingType, model) aggregation — no ORDER BY in ORM;
# sort by total cost descending in Python after fetch.
PrismaLog.prisma().group_by(
by=["provider", "trackingType", "model"],
where=where,
sum=sum_fields,
count=True,
),
query_raw_with_schema(
f"""
SELECT
p."userId" AS user_id,
u."email",
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_p}
GROUP BY p."userId", u."email"
ORDER BY total_cost DESC
LIMIT {MAX_USER_ROWS}
""",
*params_p,
# userId aggregation — emails fetched separately below.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
sum=sum_fields,
count=True,
),
query_raw_with_schema(
f"""
SELECT COUNT(DISTINCT p."userId")::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
""",
*params_p,
# Distinct user count: group by userId, count groups.
PrismaLog.prisma().group_by(
by=["userId"],
where=where,
count=True,
),
# Separate aggregate query so dashboard totals are never derived
# from the capped by_provider_rows list. With model-level grouping,
# MAX_PROVIDER_ROWS is hit more easily; summing the capped rows
# would silently undercount once >500 (provider, type, model) exist.
query_raw_with_schema(
f"""
SELECT
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
""",
*params_p,
# Total aggregate: group by provider (no limit) to sum across all
# matching rows. Summed in Python to get grand totals.
PrismaLog.prisma().group_by(
by=["provider"],
where=where,
sum={"costMicrodollars": True},
count=True,
),
)
)
# Use the exact COUNT(DISTINCT userId) so total_users is not capped at
# MAX_USER_ROWS (which would silently report 100 for >100 active users).
total_users = int(total_user_rows[0]["cnt"]) if total_user_rows else 0
total_cost = int(total_agg_rows[0]["total_cost"]) if total_agg_rows else 0
total_requests = int(total_agg_rows[0]["request_count"]) if total_agg_rows else 0
# Sort by_provider by total cost descending and cap at MAX_PROVIDER_ROWS.
by_provider_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True)
by_provider_groups = by_provider_groups[:MAX_PROVIDER_ROWS]
# Sort by_user by total cost descending and cap at MAX_USER_ROWS.
by_user_groups.sort(key=lambda r: _si(r, "costMicrodollars"), reverse=True)
by_user_groups = by_user_groups[:MAX_USER_ROWS]
# Batch-fetch emails for the users in by_user.
user_ids = [r["userId"] for r in by_user_groups if r.get("userId") is not None]
email_by_user_id: dict[str, str | None] = {}
if user_ids:
users = await PrismaUser.prisma().find_many(
where={"id": {"in": user_ids}},
)
email_by_user_id = {u.id: u.email for u in users}
# Total distinct users — exclude the NULL-userId group (deleted users).
total_users = len([g for g in total_user_groups if g.get("userId") is not None])
# Grand totals — sum across all provider groups (no LIMIT applied above).
total_cost = sum(_si(r, "costMicrodollars") for r in total_agg_groups)
total_requests = sum(_ca(r) for r in total_agg_groups)
return PlatformCostDashboard(
by_provider=[
ProviderCostSummary(
provider=r["provider"],
tracking_type=r.get("tracking_type"),
tracking_type=r.get("trackingType"),
model=r.get("model"),
total_cost_microdollars=r["total_cost"],
total_input_tokens=r["total_input_tokens"],
total_output_tokens=r["total_output_tokens"],
total_cache_read_tokens=r.get("total_cache_read_tokens", 0),
total_cache_creation_tokens=r.get("total_cache_creation_tokens", 0),
total_duration_seconds=r.get("total_duration", 0.0),
total_tracking_amount=r.get("total_tracking_amount", 0.0),
request_count=r["request_count"],
total_cost_microdollars=_si(r, "costMicrodollars"),
total_input_tokens=_si(r, "inputTokens"),
total_output_tokens=_si(r, "outputTokens"),
total_cache_read_tokens=_si(r, "cacheReadTokens"),
total_cache_creation_tokens=_si(r, "cacheCreationTokens"),
total_duration_seconds=_sf(r, "duration"),
total_tracking_amount=_sf(r, "trackingAmount"),
request_count=_ca(r),
)
for r in by_provider_rows
for r in by_provider_groups
],
by_user=[
UserCostSummary(
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
total_cost_microdollars=r["total_cost"],
total_input_tokens=r["total_input_tokens"],
total_output_tokens=r["total_output_tokens"],
request_count=r["request_count"],
user_id=r.get("userId"),
email=_mask_email(email_by_user_id.get(r.get("userId") or "")),
total_cost_microdollars=_si(r, "costMicrodollars"),
total_input_tokens=_si(r, "inputTokens"),
total_output_tokens=_si(r, "outputTokens"),
request_count=_ca(r),
)
for r in by_user_rows
for r in by_user_groups
],
total_cost_microdollars=total_cost,
total_requests=total_requests,
@@ -365,73 +370,41 @@ async def get_platform_cost_logs(
) -> tuple[list[CostLogRow], int]:
if start is None:
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_sql, params = _build_where(
start, end, provider, user_id, "p", model, block_name, tracking_type
)
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
)
offset = (page - 1) * page_size
limit_idx = len(params) + 1
offset_idx = len(params) + 2
count_rows, rows = await asyncio.gather(
query_raw_with_schema(
f"""
SELECT COUNT(*)::bigint AS cnt
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_sql}
""",
*params,
),
query_raw_with_schema(
f"""
SELECT
p."id",
p."createdAt" AS created_at,
p."userId" AS user_id,
u."email",
p."graphExecId" AS graph_exec_id,
p."nodeExecId" AS node_exec_id,
p."blockName" AS block_name,
p."provider",
p."trackingType" AS tracking_type,
p."costMicrodollars" AS cost_microdollars,
p."inputTokens" AS input_tokens,
p."outputTokens" AS output_tokens,
p."cacheReadTokens" AS cache_read_tokens,
p."cacheCreationTokens" AS cache_creation_tokens,
p."duration",
p."model"
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_sql}
ORDER BY p."createdAt" DESC, p."id" DESC
LIMIT ${limit_idx} OFFSET ${offset_idx}
""",
*params,
page_size,
offset,
total, rows = await asyncio.gather(
PrismaLog.prisma().count(where=where),
PrismaLog.prisma().find_many(
where=where,
include={"User": True},
order=[{"createdAt": "desc"}, {"id": "desc"}],
take=page_size,
skip=offset,
),
)
total = count_rows[0]["cnt"] if count_rows else 0
logs = [
CostLogRow(
id=r["id"],
created_at=r["created_at"],
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
graph_exec_id=r.get("graph_exec_id"),
node_exec_id=r.get("node_exec_id"),
block_name=r["block_name"],
provider=r["provider"],
tracking_type=r.get("tracking_type"),
cost_microdollars=r.get("cost_microdollars"),
input_tokens=r.get("input_tokens"),
output_tokens=r.get("output_tokens"),
cache_read_tokens=r.get("cache_read_tokens"),
cache_creation_tokens=r.get("cache_creation_tokens"),
duration=r.get("duration"),
model=r.get("model"),
id=r.id,
created_at=r.createdAt,
user_id=r.userId,
email=_mask_email(r.User.email if r.User else None),
graph_exec_id=r.graphExecId,
node_exec_id=r.nodeExecId,
block_name=r.blockName or "",
provider=r.provider,
tracking_type=r.trackingType,
cost_microdollars=r.costMicrodollars,
input_tokens=r.inputTokens,
output_tokens=r.outputTokens,
cache_read_tokens=getattr(r, "cacheReadTokens", None),
cache_creation_tokens=getattr(r, "cacheCreationTokens", None),
duration=r.duration,
model=r.model,
)
for r in rows
]
@@ -457,38 +430,16 @@ async def get_platform_cost_logs_for_export(
"""
if start is None:
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_sql, params = _build_where(
start, end, provider, user_id, "p", model, block_name, tracking_type
)
limit_idx = len(params) + 1
rows = await query_raw_with_schema(
f"""
SELECT
p."id",
p."createdAt" AS created_at,
p."userId" AS user_id,
u."email",
p."graphExecId" AS graph_exec_id,
p."nodeExecId" AS node_exec_id,
p."blockName" AS block_name,
p."provider",
p."trackingType" AS tracking_type,
p."costMicrodollars" AS cost_microdollars,
p."inputTokens" AS input_tokens,
p."outputTokens" AS output_tokens,
p."cacheReadTokens" AS cache_read_tokens,
p."cacheCreationTokens" AS cache_creation_tokens,
p."duration",
p."model"
FROM {{schema_prefix}}"PlatformCostLog" p
LEFT JOIN {{schema_prefix}}"User" u ON u."id" = p."userId"
WHERE {where_sql}
ORDER BY p."createdAt" DESC, p."id" DESC
LIMIT ${limit_idx}
""",
*params,
EXPORT_MAX_ROWS + 1,
where = _build_prisma_where(
start, end, provider, user_id, model, block_name, tracking_type
)
rows = await PrismaLog.prisma().find_many(
where=where,
include={"User": True},
order=[{"createdAt": "desc"}, {"id": "desc"}],
take=EXPORT_MAX_ROWS + 1,
)
truncated = len(rows) > EXPORT_MAX_ROWS
@@ -496,22 +447,80 @@ async def get_platform_cost_logs_for_export(
return [
CostLogRow(
id=r["id"],
created_at=r["created_at"],
user_id=r.get("user_id"),
email=_mask_email(r.get("email")),
graph_exec_id=r.get("graph_exec_id"),
node_exec_id=r.get("node_exec_id"),
block_name=r["block_name"],
provider=r["provider"],
tracking_type=r.get("tracking_type"),
cost_microdollars=r.get("cost_microdollars"),
input_tokens=r.get("input_tokens"),
output_tokens=r.get("output_tokens"),
cache_read_tokens=r.get("cache_read_tokens"),
cache_creation_tokens=r.get("cache_creation_tokens"),
duration=r.get("duration"),
model=r.get("model"),
id=r.id,
created_at=r.createdAt,
user_id=r.userId,
email=_mask_email(r.User.email if r.User else None),
graph_exec_id=r.graphExecId,
node_exec_id=r.nodeExecId,
block_name=r.blockName or "",
provider=r.provider,
tracking_type=r.trackingType,
cost_microdollars=r.costMicrodollars,
input_tokens=r.inputTokens,
output_tokens=r.outputTokens,
cache_read_tokens=getattr(r, "cacheReadTokens", None),
cache_creation_tokens=getattr(r, "cacheCreationTokens", None),
duration=r.duration,
model=r.model,
)
for r in rows
], truncated
# ---------------------------------------------------------------------------
# Helpers kept for backward-compatibility with existing tests.
# New code should not use these — use _build_prisma_where instead.
# ---------------------------------------------------------------------------
def _build_where(
start: datetime | None,
end: datetime | None,
provider: str | None,
user_id: str | None,
table_alias: str = "",
model: str | None = None,
block_name: str | None = None,
tracking_type: str | None = None,
) -> tuple[str, list[Any]]:
"""Legacy SQL WHERE builder — retained so existing unit tests still pass.
Only used by tests that verify the SQL-string generation logic. All
production code uses _build_prisma_where instead.
"""
prefix = f"{table_alias}." if table_alias else ""
clauses: list[str] = []
params: list[Any] = []
idx = 1
if start:
clauses.append(f'{prefix}"createdAt" >= ${idx}::timestamptz')
params.append(start)
idx += 1
if end:
clauses.append(f'{prefix}"createdAt" <= ${idx}::timestamptz')
params.append(end)
idx += 1
if provider:
clauses.append(f'{prefix}"provider" = ${idx}')
params.append(provider.lower())
idx += 1
if user_id:
clauses.append(f'{prefix}"userId" = ${idx}')
params.append(user_id)
idx += 1
if model:
clauses.append(f'{prefix}"model" = ${idx}')
params.append(model)
idx += 1
if block_name:
clauses.append(f'LOWER({prefix}"blockName") = LOWER(${idx})')
params.append(block_name)
idx += 1
if tracking_type:
clauses.append(f'{prefix}"trackingType" = ${idx}')
params.append(tracking_type)
idx += 1
return (" AND ".join(clauses) if clauses else "TRUE", params)

View File

@@ -1,7 +1,7 @@
"""Unit tests for helpers and async functions in platform_cost module."""
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from prisma import Json
@@ -224,6 +224,41 @@ class TestLogPlatformCostSafe:
mock_create.assert_awaited_once()
def _make_group_by_row(
provider: str = "openai",
tracking_type: str | None = "tokens",
model: str | None = None,
cost: int = 5000,
input_tokens: int = 1000,
output_tokens: int = 500,
cache_read_tokens: int = 0,
cache_creation_tokens: int = 0,
duration: float = 10.5,
tracking_amount: float = 0.0,
count: int = 3,
user_id: str | None = None,
) -> dict:
row: dict = {
"_sum": {
"costMicrodollars": cost,
"inputTokens": input_tokens,
"outputTokens": output_tokens,
"cacheReadTokens": cache_read_tokens,
"cacheCreationTokens": cache_creation_tokens,
"duration": duration,
"trackingAmount": tracking_amount,
},
"_count": {"_all": count},
}
if user_id is not None:
row["userId"] = user_id
else:
row["provider"] = provider
row["trackingType"] = tracking_type
row["model"] = model
return row
class TestGetPlatformCostDashboard:
def setup_method(self):
# @cached stores results in-process; clear between tests to avoid bleed.
@@ -231,35 +266,44 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_returns_dashboard_with_data(self):
provider_rows = [
{
"provider": "openai",
"tracking_type": "tokens",
"total_cost": 5000,
"total_input_tokens": 1000,
"total_output_tokens": 500,
"total_duration": 10.5,
"request_count": 3,
}
]
user_rows = [
{
"user_id": "u1",
"email": "a@b.com",
"total_cost": 5000,
"total_input_tokens": 1000,
"total_output_tokens": 500,
"request_count": 3,
}
]
# Dashboard runs 4 queries: by_provider, by_user, COUNT(DISTINCT userId),
# and a separate total aggregate (total_cost + request_count with no LIMIT).
agg_rows = [{"total_cost": 5000, "request_count": 3}]
mock_query = AsyncMock(
side_effect=[provider_rows, user_rows, [{"cnt": 1}], agg_rows]
provider_row = _make_group_by_row(
provider="openai",
tracking_type="tokens",
cost=5000,
input_tokens=1000,
output_tokens=500,
duration=10.5,
count=3,
)
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
user_row = _make_group_by_row(user_id="u1", cost=5000, count=3)
mock_user = MagicMock()
mock_user.id = "u1"
mock_user.email = "a@b.com"
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[{"userId": "u1"}], # distinct users
[provider_row], # total agg
]
)
mock_actions.find_many = AsyncMock(return_value=[mock_user])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 5000
assert dashboard.total_requests == 3
assert dashboard.total_users == 1
@@ -271,10 +315,67 @@ class TestGetPlatformCostDashboard:
assert dashboard.by_user[0].email == "a***@b.com"
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_query = AsyncMock(side_effect=[[], [], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
async def test_cache_tokens_aggregated_not_hardcoded(self):
"""cache_read_tokens and cache_creation_tokens must be read from the
DB aggregation, not hardcoded to 0 (regression guard for Sentry report)."""
provider_row = _make_group_by_row(
provider="anthropic",
tracking_type="tokens",
cost=1000,
input_tokens=800,
output_tokens=200,
cache_read_tokens=400,
cache_creation_tokens=100,
count=1,
)
user_row = _make_group_by_row(user_id="u2", cost=1000, count=1)
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(
side_effect=[
[provider_row], # by_provider
[user_row], # by_user
[{"userId": "u2"}], # distinct users
[provider_row], # total agg
]
)
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
):
dashboard = await get_platform_cost_dashboard()
assert len(dashboard.by_provider) == 1
row = dashboard.by_provider[0]
assert row.total_cache_read_tokens == 400
assert row.total_cache_creation_tokens == 100
@pytest.mark.asyncio
async def test_returns_empty_dashboard(self):
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
):
dashboard = await get_platform_cost_dashboard()
assert dashboard.total_cost_microdollars == 0
assert dashboard.total_requests == 0
assert dashboard.total_users == 0
@@ -284,160 +385,228 @@ class TestGetPlatformCostDashboard:
@pytest.mark.asyncio
async def test_passes_filters_to_queries(self):
start = datetime(2026, 1, 1, tzinfo=timezone.utc)
mock_query = AsyncMock(side_effect=[[], [], [], []])
with patch("backend.data.platform_cost.query_raw_with_schema", new=mock_query):
mock_actions = MagicMock()
mock_actions.group_by = AsyncMock(side_effect=[[], [], [], []])
mock_actions.find_many = AsyncMock(return_value=[])
with (
patch(
"backend.data.platform_cost.PrismaLog.prisma",
return_value=mock_actions,
),
patch(
"backend.data.platform_cost.PrismaUser.prisma",
return_value=mock_actions,
),
):
await get_platform_cost_dashboard(
start=start, provider="openai", user_id="u1"
)
assert mock_query.await_count == 4
first_call_sql = mock_query.call_args_list[0][0][0]
assert "createdAt" in first_call_sql
# group_by called 4 times (by_provider, by_user, distinct users, totals)
assert mock_actions.group_by.await_count == 4
# The where dict passed to the first call should include createdAt
first_call_kwargs = mock_actions.group_by.call_args_list[0][1]
assert "createdAt" in first_call_kwargs.get("where", {})
def _make_prisma_log_row(
i: int = 0,
user_email: str | None = None,
) -> MagicMock:
row = MagicMock()
row.id = f"log-{i}"
row.createdAt = datetime(2026, 3, 1, tzinfo=timezone.utc)
row.userId = "u1"
row.graphExecId = None
row.nodeExecId = None
row.blockName = "TestBlock"
row.provider = "openai"
row.trackingType = "tokens"
row.costMicrodollars = 1000
row.inputTokens = 10
row.outputTokens = 5
row.duration = 0.5
row.model = "gpt-4"
# cacheReadTokens / cacheCreationTokens may not exist on older Prisma clients
row.configure_mock(**{"cacheReadTokens": None, "cacheCreationTokens": None})
if user_email is not None:
row.User = MagicMock()
row.User.email = user_email
else:
row.User = None
return row
class TestGetPlatformCostLogs:
    """Tests for ``get_platform_cost_logs`` (paginated admin log listing).

    NOTE(review): this class previously contained an unresolved merge of two
    mocking strategies (raw-SQL ``query_raw_with_schema`` vs. the Prisma
    ``PrismaLog.prisma`` client), which left contradictory asserts
    (``logs[0].id`` checked against both "log-1" and "log-0") and a
    ``NameError`` on an undefined ``mock_actions``. It is resolved here to the
    Prisma-client strategy — confirm this matches the current implementation
    in ``backend.data.platform_cost``.
    """

    @staticmethod
    def _mock_actions(count: int = 0, rows: list | None = None) -> MagicMock:
        """Build a mocked Prisma actions object with async count/find_many."""
        actions = MagicMock()
        actions.count = AsyncMock(return_value=count)
        actions.find_many = AsyncMock(return_value=rows or [])
        return actions

    @pytest.mark.asyncio
    async def test_returns_logs_and_total(self):
        """A single Prisma row maps to one log entry; total equals count."""
        row = _make_prisma_log_row(0, user_email="a@b.com")
        mock_actions = self._mock_actions(count=1, rows=[row])
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs(page=1, page_size=10)
        assert total == 1
        assert len(logs) == 1
        assert logs[0].id == "log-0"
        assert logs[0].provider == "openai"
        assert logs[0].model == "gpt-4"

    @pytest.mark.asyncio
    async def test_returns_empty_when_no_data(self):
        """No rows yields an empty list and a zero total."""
        mock_actions = self._mock_actions()
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs()
        assert total == 0
        assert logs == []

    @pytest.mark.asyncio
    async def test_pagination_offset(self):
        """page/page_size are translated into Prisma take/skip arguments."""
        mock_actions = self._mock_actions(count=100)
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs(page=3, page_size=25)
        assert total == 100
        find_many_call = mock_actions.find_many.call_args[1]
        assert find_many_call["take"] == 25
        assert find_many_call["skip"] == 50  # (3 - 1) * 25

    @pytest.mark.asyncio
    async def test_empty_count_returns_zero(self):
        """A zero count produces an empty result, not an error."""
        # NOTE(review): under the raw-SQL branch this covered an empty count
        # result set ([]); with the Prisma client a count of 0 is the
        # equivalent degenerate case.
        mock_actions = self._mock_actions()
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs()
        assert total == 0
        assert logs == []

    @pytest.mark.asyncio
    async def test_explicit_start_skips_default(self):
        """An explicit start date must appear in the createdAt where-filter."""
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        mock_actions = self._mock_actions()
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, total = await get_platform_cost_logs(start=start)
        assert total == 0
        where = mock_actions.count.call_args[1]["where"]
        # start provided — should appear in the where filter
        assert "createdAt" in where
def _make_log_row(i: int = 0) -> dict:
return {
"id": f"log-{i}",
"created_at": datetime(2026, 3, 1, tzinfo=timezone.utc),
"user_id": "u1",
"email": None,
"graph_exec_id": None,
"node_exec_id": None,
"block_name": "TestBlock",
"provider": "openai",
"tracking_type": "tokens",
"cost_microdollars": 1000,
"input_tokens": 10,
"output_tokens": 5,
"duration": 0.5,
"model": "gpt-4",
"cache_read_tokens": None,
"cache_creation_tokens": None,
}
# NOTE(review): removed stray module-level statements left behind by a merge.
# They referenced an undefined `mock_actions` and would raise NameError at
# import time; the createdAt where-filter check they performed belongs inside
# TestGetPlatformCostLogs.test_explicit_start_skips_default.
class TestGetPlatformCostLogsForExport:
    """Tests for ``get_platform_cost_logs_for_export`` (bulk export path).

    NOTE(review): this class previously mixed raw-SQL and Prisma mocking from
    an unresolved merge — e.g. ``test_passes_model_block_tracking_filters``
    asserted on call args of BOTH mocks even though only one layer can be
    invoked, so one assert set always failed. Resolved here to the
    Prisma-client strategy — confirm against the current implementation in
    ``backend.data.platform_cost``.
    """

    @staticmethod
    def _mock_actions(rows: list | None = None) -> MagicMock:
        """Build a mocked Prisma actions object whose find_many yields *rows*."""
        actions = MagicMock()
        actions.find_many = AsyncMock(return_value=rows or [])
        return actions

    @pytest.mark.asyncio
    async def test_returns_logs_not_truncated(self):
        """Fewer rows than EXPORT_MAX_ROWS -> all rows kept, not truncated."""
        mock_actions = self._mock_actions([_make_prisma_log_row(0)])
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, truncated = await get_platform_cost_logs_for_export()
        assert len(logs) == 1
        assert truncated is False
        assert logs[0].id == "log-0"

    @pytest.mark.asyncio
    async def test_returns_empty_not_truncated(self):
        """No rows -> empty list and truncated is False."""
        mock_actions = self._mock_actions()
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, truncated = await get_platform_cost_logs_for_export()
        assert logs == []
        assert truncated is False

    @pytest.mark.asyncio
    async def test_truncates_at_export_max_rows(self):
        """More rows than EXPORT_MAX_ROWS -> result is capped and flagged."""
        rows = [_make_prisma_log_row(i) for i in range(3)]
        mock_actions = self._mock_actions(rows)
        with (
            patch(
                "backend.data.platform_cost.PrismaLog.prisma",
                return_value=mock_actions,
            ),
            patch("backend.data.platform_cost.EXPORT_MAX_ROWS", 2),
        ):
            logs, truncated = await get_platform_cost_logs_for_export()
        assert len(logs) == 2
        assert truncated is True

    @pytest.mark.asyncio
    async def test_passes_model_block_tracking_filters(self):
        """model / block_name / tracking_type filters reach the where clause."""
        mock_actions = self._mock_actions()
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            await get_platform_cost_logs_for_export(
                model="gpt-4", block_name="LLMBlock", tracking_type="tokens"
            )
        where = mock_actions.find_many.call_args[1]["where"]
        assert where.get("model") == "gpt-4"
        assert where.get("trackingType") == "tokens"
        # blockName uses a dict filter for case-insensitive match
        assert "blockName" in where

    @pytest.mark.asyncio
    async def test_maps_cache_tokens(self):
        """Anthropic cache-token columns are mapped onto the export model."""
        row = _make_prisma_log_row(0)
        row.configure_mock(**{"cacheReadTokens": 50, "cacheCreationTokens": 25})
        mock_actions = self._mock_actions([row])
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, _ = await get_platform_cost_logs_for_export()
        assert logs[0].cache_read_tokens == 50
        assert logs[0].cache_creation_tokens == 25

    @pytest.mark.asyncio
    async def test_explicit_start_skips_default(self):
        """An explicit start date must appear in the createdAt where-filter."""
        start = datetime(2026, 1, 1, tzinfo=timezone.utc)
        mock_actions = self._mock_actions()
        with patch(
            "backend.data.platform_cost.PrismaLog.prisma",
            return_value=mock_actions,
        ):
            logs, truncated = await get_platform_cost_logs_for_export(start=start)
        assert logs == []
        assert truncated is False
        where = mock_actions.find_many.call_args[1]["where"]
        assert "createdAt" in where