fix(platform): address 4 review blockers on cost tracking

- Fire-and-forget cost logging via asyncio.create_task() instead of await
  to avoid blocking executor and copilot streaming paths on DB INSERT
- Add trackingType column to PlatformCostLog schema, migration, and INSERT;
  update dashboard/logs queries to use COALESCE(column, JSONB) for backward
  compat and index-friendly GROUP BY
- Admin auth test now explicitly mocks get_jwt_payload to raise 401 instead
  of relying on bare FastAPI app behavior
- Blocker 3 (nullable user_id) was already addressed in prior commit
This commit is contained in:
Zamil Majdy
2026-04-03 22:43:57 +02:00
parent b00e16b438
commit 8d22653810
8 changed files with 91 additions and 45 deletions

View File

@@ -130,6 +130,15 @@ def test_get_logs_with_pagination(
def test_get_dashboard_requires_admin() -> None:
app.dependency_overrides.clear()
response = client.get("/admin/platform_costs/dashboard")
assert response.status_code in (401, 403)
import fastapi
from fastapi import HTTPException
def reject_jwt(request: fastapi.Request):
raise HTTPException(status_code=401, detail="Not authenticated")
app.dependency_overrides[get_jwt_payload] = reject_jwt
try:
response = client.get("/admin/platform_costs/dashboard")
assert response.status_code == 401
finally:
app.dependency_overrides.clear()

View File

@@ -9,6 +9,7 @@ Both the baseline (OpenRouter) and SDK (Anthropic) service layers need to:
This module extracts that common logic so both paths stay in sync.
"""
import asyncio
import logging
from backend.data.platform_cost import PlatformCostEntry, log_platform_cost_safe
@@ -122,25 +123,28 @@ async def persist_and_record_usage(
tracking_type = "tokens"
tracking_amount = total_tokens
await log_platform_cost_safe(
PlatformCostEntry(
user_id=user_id,
graph_exec_id=session_id,
block_id="copilot",
block_name=f"copilot:{log_prefix.strip(' []')}".rstrip(":"),
provider=provider,
credential_id="copilot_system",
cost_microdollars=cost_microdollars,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
model=model,
metadata={
"tracking_type": tracking_type,
"tracking_amount": tracking_amount,
"cache_read_tokens": cache_read_tokens,
"cache_creation_tokens": cache_creation_tokens,
"source": "copilot",
},
asyncio.create_task(
log_platform_cost_safe(
PlatformCostEntry(
user_id=user_id,
graph_exec_id=session_id,
block_id="copilot",
block_name=f"copilot:{log_prefix.strip(' []')}".rstrip(":"),
provider=provider,
credential_id="copilot_system",
cost_microdollars=cost_microdollars,
input_tokens=prompt_tokens,
output_tokens=completion_tokens,
model=model,
tracking_type=tracking_type,
metadata={
"tracking_type": tracking_type,
"tracking_amount": tracking_amount,
"cache_read_tokens": cache_read_tokens,
"cache_creation_tokens": cache_creation_tokens,
"source": "copilot",
},
)
)
)

View File

@@ -4,6 +4,7 @@ Covers both the baseline (prompt+completion only) and SDK (with cache breakdown)
calling conventions, session persistence, and rate-limit recording.
"""
import asyncio
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
@@ -311,6 +312,7 @@ class TestPlatformCostLogging:
provider="anthropic",
log_prefix="[SDK]",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.user_id == "user-cost"
@@ -319,6 +321,7 @@ class TestPlatformCostLogging:
assert entry.cost_microdollars == 5000
assert entry.input_tokens == 200
assert entry.output_tokens == 100
assert entry.tracking_type == "cost_usd"
assert entry.metadata["tracking_type"] == "cost_usd"
assert entry.metadata["tracking_amount"] == 0.005
assert entry.block_name == "copilot:SDK"
@@ -345,9 +348,11 @@ class TestPlatformCostLogging:
completion_tokens=50,
log_prefix="[Baseline]",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars is None
assert entry.tracking_type == "tokens"
assert entry.metadata["tracking_type"] == "tokens"
assert entry.metadata["tracking_amount"] == 150
assert entry.graph_exec_id is None
@@ -373,6 +378,7 @@ class TestPlatformCostLogging:
prompt_tokens=100,
completion_tokens=50,
)
await asyncio.sleep(0)
mock_log.assert_not_awaited()
@pytest.mark.asyncio
@@ -396,6 +402,7 @@ class TestPlatformCostLogging:
completion_tokens=50,
cost_usd="not-a-number",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars is None
@@ -422,6 +429,7 @@ class TestPlatformCostLogging:
completion_tokens=50,
cost_usd="0.01",
)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars == 10_000
@@ -448,6 +456,7 @@ class TestPlatformCostLogging:
completion_tokens=5,
log_prefix="",
)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.block_name == "copilot"
@@ -473,6 +482,7 @@ class TestPlatformCostLogging:
cache_read_tokens=5000,
cache_creation_tokens=300,
)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.metadata["cache_read_tokens"] == 5000
assert entry.metadata["cache_creation_tokens"] == 300

View File

@@ -27,6 +27,7 @@ class PlatformCostEntry(BaseModel):
data_size: int | None = None
duration: float | None = None
model: str | None = None
tracking_type: str | None = None
metadata: dict[str, Any] | None = None
@@ -37,10 +38,10 @@ async def log_platform_cost(entry: PlatformCostEntry) -> None:
("id", "createdAt", "userId", "graphExecId", "nodeExecId",
"graphId", "nodeId", "blockId", "blockName", "provider",
"credentialId", "costMicrodollars", "inputTokens", "outputTokens",
"dataSize", "duration", "model", "metadata")
"dataSize", "duration", "model", "trackingType", "metadata")
VALUES (
gen_random_uuid(), NOW(), $1, $2, $3, $4, $5, $6, $7, $8, $9,
$10, $11, $12, $13, $14, $15, $16::jsonb
$10, $11, $12, $13, $14, $15, $16, $17::jsonb
)
""",
entry.user_id,
@@ -58,6 +59,7 @@ async def log_platform_cost(entry: PlatformCostEntry) -> None:
entry.data_size,
entry.duration,
entry.model,
entry.tracking_type,
_json_or_none(entry.metadata),
)
@@ -170,7 +172,8 @@ async def get_platform_cost_dashboard(
f"""
SELECT
p."provider",
p."metadata"->>'tracking_type' AS tracking_type,
COALESCE(p."trackingType", p."metadata"->>'tracking_type')
AS tracking_type,
COALESCE(SUM(p."costMicrodollars"), 0)::bigint AS total_cost,
COALESCE(SUM(p."inputTokens"), 0)::bigint AS total_input_tokens,
COALESCE(SUM(p."outputTokens"), 0)::bigint AS total_output_tokens,
@@ -178,7 +181,8 @@ async def get_platform_cost_dashboard(
COUNT(*)::bigint AS request_count
FROM {{schema_prefix}}"PlatformCostLog" p
WHERE {where_p}
GROUP BY p."provider", p."metadata"->>'tracking_type'
GROUP BY p."provider",
COALESCE(p."trackingType", p."metadata"->>'tracking_type')
ORDER BY total_cost DESC
""",
*params_p,
@@ -279,7 +283,8 @@ async def get_platform_cost_logs(
p."nodeExecId" AS node_exec_id,
p."blockName" AS block_name,
p."provider",
p."metadata"->>'tracking_type' AS tracking_type,
COALESCE(p."trackingType", p."metadata"->>'tracking_type')
AS tracking_type,
p."costMicrodollars" AS cost_microdollars,
p."inputTokens" AS input_tokens,
p."outputTokens" AS output_tokens,

View File

@@ -1,5 +1,6 @@
"""Helpers for platform cost tracking on system-credential block executions."""
import asyncio
import logging
from typing import Any, cast
@@ -109,24 +110,27 @@ async def log_system_credential_cost(
if stats.provider_cost is not None:
meta["provider_cost_usd"] = stats.provider_cost
await log_platform_cost_safe(
PlatformCostEntry(
user_id=node_exec.user_id,
graph_exec_id=node_exec.graph_exec_id,
node_exec_id=node_exec.node_exec_id,
graph_id=node_exec.graph_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block_name=block.name,
provider=provider_name,
credential_id=cred_id,
cost_microdollars=cost_microdollars,
input_tokens=stats.input_token_count or None,
output_tokens=stats.output_token_count or None,
data_size=stats.output_size or None,
duration=stats.walltime or None,
model=model_name,
metadata=meta or None,
asyncio.create_task(
log_platform_cost_safe(
PlatformCostEntry(
user_id=node_exec.user_id,
graph_exec_id=node_exec.graph_exec_id,
node_exec_id=node_exec.node_exec_id,
graph_id=node_exec.graph_id,
node_id=node_exec.node_id,
block_id=node_exec.block_id,
block_name=block.name,
provider=provider_name,
credential_id=cred_id,
cost_microdollars=cost_microdollars,
input_tokens=stats.input_token_count or None,
output_tokens=stats.output_token_count or None,
data_size=stats.output_size or None,
duration=stats.walltime or None,
model=model_name,
tracking_type=tracking_type,
metadata=meta or None,
)
)
)
return # One log per execution is enough

View File

@@ -1,5 +1,6 @@
"""Unit tests for resolve_tracking and log_system_credential_cost."""
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@@ -248,6 +249,7 @@ class TestLogSystemCredentialCost:
block = _make_block()
stats = NodeExecutionStats(input_token_count=500, output_token_count=200)
await log_system_credential_cost(node_exec, block, stats)
await asyncio.sleep(0)
mock_log.assert_awaited_once()
entry = mock_log.call_args[0][0]
@@ -257,6 +259,7 @@ class TestLogSystemCredentialCost:
assert entry.model == "gpt-4"
assert entry.input_tokens == 500
assert entry.output_tokens == 200
assert entry.tracking_type == "tokens"
assert entry.metadata["tracking_type"] == "tokens"
assert entry.metadata["tracking_amount"] == 700.0
assert entry.metadata["credit_cost"] == 10
@@ -284,9 +287,11 @@ class TestLogSystemCredentialCost:
block = _make_block()
stats = NodeExecutionStats(provider_cost=0.0015)
await log_system_credential_cost(node_exec, block, stats)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars == 1500
assert entry.tracking_type == "cost_usd"
assert entry.metadata["tracking_type"] == "cost_usd"
assert entry.metadata["provider_cost_usd"] == 0.0015
@@ -319,6 +324,7 @@ class TestLogSystemCredentialCost:
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.model == "FakeModel.GPT4"
@@ -347,6 +353,7 @@ class TestLogSystemCredentialCost:
block = _make_block()
stats = NodeExecutionStats()
await log_system_credential_cost(node_exec, block, stats)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.model is None
@@ -376,6 +383,7 @@ class TestLogSystemCredentialCost:
# round() should give 1500, int() would give 1499
stats = NodeExecutionStats(provider_cost=0.0015)
await log_system_credential_cost(node_exec, block, stats)
await asyncio.sleep(0)
entry = mock_log.call_args[0][0]
assert entry.cost_microdollars == 1500

View File

@@ -17,6 +17,7 @@ CREATE TABLE "PlatformCostLog" (
"dataSize" INTEGER,
"duration" DOUBLE PRECISION,
"model" TEXT,
"trackingType" TEXT,
"metadata" JSONB,
CONSTRAINT "PlatformCostLog_pkey" PRIMARY KEY ("id")
@@ -34,5 +35,8 @@ CREATE INDEX "PlatformCostLog_createdAt_idx" ON "PlatformCostLog"("createdAt");
-- CreateIndex
CREATE INDEX "PlatformCostLog_graphExecId_idx" ON "PlatformCostLog"("graphExecId");
-- CreateIndex
CREATE INDEX "PlatformCostLog_provider_trackingType_idx" ON "PlatformCostLog"("provider", "trackingType");
-- AddForeignKey
ALTER TABLE "PlatformCostLog" ADD CONSTRAINT "PlatformCostLog_userId_fkey" FOREIGN KEY ("userId") REFERENCES "User"("id") ON DELETE SET NULL ON UPDATE CASCADE;

View File

@@ -845,12 +845,14 @@ model PlatformCostLog {
dataSize Int? // bytes
duration Float? // seconds
model String?
trackingType String? // e.g. "cost_usd", "tokens", "per_run"
metadata Json?
@@index([userId, createdAt])
@@index([provider, createdAt])
@@index([createdAt])
@@index([graphExecId])
@@index([provider, trackingType])
}
////////////////////////////////////////////////////////////