fix(backend): address 4 unresolved review threads on cost tracking

1. cost_tracking.py: replace shared _log_semaphore with per-loop dict
   (_log_semaphores + _get_log_semaphore()) — asyncio.Semaphore is not
   thread-safe and must not be shared across executor worker threads/loops

2. cost_tracking.py: only honor provider_cost_type when provider_cost is
   also present (not None); use tracking_amount (not raw stats.provider_cost)
   in usd_to_microdollars() to avoid unit mismatches

3. token_tracking.py: add semaphore to _schedule_cost_log (same pattern
   as cost_tracking.py) to bound concurrent DB inserts under load; drop
   the unnecessary string quoting from the _pending_log_tasks type annotation

4. baseline/service.py: validate x-total-cost header before accumulating —
   math.isfinite rejects nan/inf values that float() accepts but that
   should never reach the persistence path, and max(0.0, cost) clamps
   negative values to zero
This commit is contained in:
Zamil Majdy
2026-04-07 21:25:56 +07:00
parent b89321a688
commit 1c3fe1444e
3 changed files with 47 additions and 18 deletions

View File

@@ -9,6 +9,7 @@ shared tool registry as the SDK path.
import asyncio
import base64
import logging
import math
import os
import re
import shutil
@@ -441,7 +442,9 @@ async def _baseline_llm_caller(
if raw_resp and hasattr(raw_resp, "headers"):
cost_header = raw_resp.headers.get("x-total-cost")
if cost_header:
state.cost_usd = (state.cost_usd or 0.0) + float(cost_header)
cost = float(cost_header)
if math.isfinite(cost):
state.cost_usd = (state.cost_usd or 0.0) + max(0.0, cost)
except (ValueError, AttributeError, UnboundLocalError):
pass

View File

@@ -22,22 +22,35 @@ from .rate_limit import record_token_usage
logger = logging.getLogger(__name__)
# Hold strong references to in-flight cost log tasks to prevent GC.
_pending_log_tasks: set["asyncio.Task[None]"] = set()
_pending_log_tasks: set[asyncio.Task[None]] = set()
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}


def _get_log_semaphore() -> asyncio.Semaphore:
    """Return the cost-log semaphore that belongs to the running event loop.

    A semaphore with 50 permits is created the first time a given loop asks
    for one and is cached in ``_log_semaphores``.  When a new entry is about
    to be registered, entries whose loop has been closed are evicted first,
    so the registry does not keep dead loops (and their semaphores) alive
    for the lifetime of the process.

    Returns:
        The semaphore associated with the currently running loop.

    Raises:
        RuntimeError: if called while no event loop is running (raised by
            ``asyncio.get_running_loop``).
    """
    loop = asyncio.get_running_loop()
    sem = _log_semaphores.get(loop)
    if sem is None:
        # Iterate over a snapshot of the keys: the registry is mutated in
        # this loop, and other threads may register their own loops too.
        for known_loop in list(_log_semaphores):
            if known_loop.is_closed():
                # A closed loop can never run tasks again, so its semaphore
                # is unreachable and safe to drop.
                _log_semaphores.pop(known_loop, None)
        sem = asyncio.Semaphore(50)
        _log_semaphores[loop] = sem
    return sem
def _schedule_cost_log(entry: PlatformCostEntry) -> None:
"""Schedule a fire-and-forget cost log via DatabaseManagerAsyncClient RPC."""
async def _safe_log() -> None:
try:
await platform_cost_db().log_platform_cost(entry)
except Exception:
logger.exception(
"Failed to log platform cost for user=%s provider=%s block=%s",
entry.user_id,
entry.provider,
entry.block_name,
)
async with _get_log_semaphore():
try:
await platform_cost_db().log_platform_cost(entry)
except Exception:
logger.exception(
"Failed to log platform cost for user=%s provider=%s block=%s",
entry.user_id,
entry.provider,
entry.block_name,
)
task = asyncio.create_task(_safe_log())
_pending_log_tasks.add(task)

View File

@@ -34,8 +34,19 @@ _WALLTIME_BILLED_PROVIDERS = frozenset(
# Hold strong references to in-flight log tasks so the event loop doesn't
# garbage-collect them mid-execution. Tasks remove themselves on completion.
_pending_log_tasks: set[asyncio.Task] = set()
# Bound concurrent DB inserts to avoid unbounded queue growth under load.
_log_semaphore = asyncio.Semaphore(50)
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
# shared across event loops running in different threads. Key by loop instance
# so each executor worker thread gets its own semaphore.
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}


def _get_log_semaphore() -> asyncio.Semaphore:
    """Return the log semaphore that belongs to the running event loop.

    The first request from a given loop creates a semaphore with 50 permits
    and caches it in ``_log_semaphores``; later requests from the same loop
    get the cached instance.  Before a new entry is registered, entries whose
    loop has been closed are removed so the registry cannot grow without
    bound as loops come and go.

    Returns:
        The semaphore associated with the currently running loop.

    Raises:
        RuntimeError: if called while no event loop is running (raised by
            ``asyncio.get_running_loop``).
    """
    loop = asyncio.get_running_loop()
    sem = _log_semaphores.get(loop)
    if sem is None:
        # Work from a snapshot of the keys: entries are removed below, and
        # other threads may be registering their own loops at the same time.
        for known_loop in list(_log_semaphores):
            if known_loop.is_closed():
                # Closed loops cannot schedule further work; drop their entry.
                _log_semaphores.pop(known_loop, None)
        sem = asyncio.Semaphore(50)
        _log_semaphores[loop] = sem
    return sem
async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
@@ -81,7 +92,7 @@ def _schedule_log(
db_client: "DatabaseManagerAsyncClient", entry: PlatformCostEntry
) -> None:
async def _safe_log() -> None:
async with _log_semaphore:
async with _get_log_semaphore():
try:
await db_client.log_platform_cost(entry)
except Exception:
@@ -125,9 +136,9 @@ def resolve_tracking(
2. Heuristic fallback: infer from `provider_cost`/token counts, then
from provider name for per-character / per-second billing.
"""
# 1. Block explicitly declared its cost type
if stats.provider_cost_type:
return stats.provider_cost_type, stats.provider_cost or 0.0
# 1. Block explicitly declared its cost type (only when an amount is present)
if stats.provider_cost_type and stats.provider_cost is not None:
return stats.provider_cost_type, stats.provider_cost
# 2. Provider returned actual USD cost (OpenRouter, Exa)
if stats.provider_cost is not None:
@@ -217,9 +228,11 @@ async def log_system_credential_cost(
# Only treat provider_cost as USD when the tracking type says so.
# For other types (items, characters, per_run, ...) the
# provider_cost field holds the raw amount, not a dollar value.
# Use tracking_amount (the normalized value from resolve_tracking)
# rather than raw stats.provider_cost to avoid unit mismatches.
cost_microdollars = None
if tracking_type == "cost_usd":
cost_microdollars = usd_to_microdollars(stats.provider_cost)
cost_microdollars = usd_to_microdollars(tracking_amount)
meta: dict[str, Any] = {
"tracking_type": tracking_type,