mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): address 4 unresolved review threads on cost tracking
1. cost_tracking.py: replace shared _log_semaphore with per-loop dict (_log_semaphores + _get_log_semaphore()) — asyncio.Semaphore is not thread-safe and must not be shared across executor worker threads/loops 2. cost_tracking.py: only honor provider_cost_type when provider_cost is also present (not None); use tracking_amount (not raw stats.provider_cost) in usd_to_microdollars() to avoid unit mismatches 3. token_tracking.py: add semaphore to _schedule_cost_log (same pattern as cost_tracking.py) to bound concurrent DB inserts under load; fix forward-reference string in _pending_log_tasks type annotation 4. baseline/service.py: validate x-total-cost header with math.isfinite and max(0.0, cost) guard before accumulating — rejects nan/inf values that float() accepts but that should never reach the persistence path
This commit is contained in:
@@ -9,6 +9,7 @@ shared tool registry as the SDK path.
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -441,7 +442,9 @@ async def _baseline_llm_caller(
|
||||
if raw_resp and hasattr(raw_resp, "headers"):
|
||||
cost_header = raw_resp.headers.get("x-total-cost")
|
||||
if cost_header:
|
||||
state.cost_usd = (state.cost_usd or 0.0) + float(cost_header)
|
||||
cost = float(cost_header)
|
||||
if math.isfinite(cost):
|
||||
state.cost_usd = (state.cost_usd or 0.0) + max(0.0, cost)
|
||||
except (ValueError, AttributeError, UnboundLocalError):
|
||||
pass
|
||||
|
||||
|
||||
@@ -22,22 +22,35 @@ from .rate_limit import record_token_usage
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Hold strong references to in-flight cost log tasks to prevent GC.
|
||||
_pending_log_tasks: set["asyncio.Task[None]"] = set()
|
||||
_pending_log_tasks: set[asyncio.Task[None]] = set()
|
||||
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
|
||||
# shared across event loops running in different threads.
|
||||
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
|
||||
|
||||
|
||||
def _get_log_semaphore() -> asyncio.Semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
sem = _log_semaphores.get(loop)
|
||||
if sem is None:
|
||||
sem = asyncio.Semaphore(50)
|
||||
_log_semaphores[loop] = sem
|
||||
return sem
|
||||
|
||||
|
||||
def _schedule_cost_log(entry: PlatformCostEntry) -> None:
|
||||
"""Schedule a fire-and-forget cost log via DatabaseManagerAsyncClient RPC."""
|
||||
|
||||
async def _safe_log() -> None:
|
||||
try:
|
||||
await platform_cost_db().log_platform_cost(entry)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to log platform cost for user=%s provider=%s block=%s",
|
||||
entry.user_id,
|
||||
entry.provider,
|
||||
entry.block_name,
|
||||
)
|
||||
async with _get_log_semaphore():
|
||||
try:
|
||||
await platform_cost_db().log_platform_cost(entry)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to log platform cost for user=%s provider=%s block=%s",
|
||||
entry.user_id,
|
||||
entry.provider,
|
||||
entry.block_name,
|
||||
)
|
||||
|
||||
task = asyncio.create_task(_safe_log())
|
||||
_pending_log_tasks.add(task)
|
||||
|
||||
@@ -34,8 +34,19 @@ _WALLTIME_BILLED_PROVIDERS = frozenset(
|
||||
# Hold strong references to in-flight log tasks so the event loop doesn't
|
||||
# garbage-collect them mid-execution. Tasks remove themselves on completion.
|
||||
_pending_log_tasks: set[asyncio.Task] = set()
|
||||
# Bound concurrent DB inserts to avoid unbounded queue growth under load.
|
||||
_log_semaphore = asyncio.Semaphore(50)
|
||||
# Per-loop semaphores: asyncio.Semaphore is not thread-safe and must not be
|
||||
# shared across event loops running in different threads. Key by loop instance
|
||||
# so each executor worker thread gets its own semaphore.
|
||||
_log_semaphores: dict[asyncio.AbstractEventLoop, asyncio.Semaphore] = {}
|
||||
|
||||
|
||||
def _get_log_semaphore() -> asyncio.Semaphore:
|
||||
loop = asyncio.get_running_loop()
|
||||
sem = _log_semaphores.get(loop)
|
||||
if sem is None:
|
||||
sem = asyncio.Semaphore(50)
|
||||
_log_semaphores[loop] = sem
|
||||
return sem
|
||||
|
||||
|
||||
async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
|
||||
@@ -81,7 +92,7 @@ def _schedule_log(
|
||||
db_client: "DatabaseManagerAsyncClient", entry: PlatformCostEntry
|
||||
) -> None:
|
||||
async def _safe_log() -> None:
|
||||
async with _log_semaphore:
|
||||
async with _get_log_semaphore():
|
||||
try:
|
||||
await db_client.log_platform_cost(entry)
|
||||
except Exception:
|
||||
@@ -125,9 +136,9 @@ def resolve_tracking(
|
||||
2. Heuristic fallback: infer from `provider_cost`/token counts, then
|
||||
from provider name for per-character / per-second billing.
|
||||
"""
|
||||
# 1. Block explicitly declared its cost type
|
||||
if stats.provider_cost_type:
|
||||
return stats.provider_cost_type, stats.provider_cost or 0.0
|
||||
# 1. Block explicitly declared its cost type (only when an amount is present)
|
||||
if stats.provider_cost_type and stats.provider_cost is not None:
|
||||
return stats.provider_cost_type, stats.provider_cost
|
||||
|
||||
# 2. Provider returned actual USD cost (OpenRouter, Exa)
|
||||
if stats.provider_cost is not None:
|
||||
@@ -217,9 +228,11 @@ async def log_system_credential_cost(
|
||||
# Only treat provider_cost as USD when the tracking type says so.
|
||||
# For other types (items, characters, per_run, ...) the
|
||||
# provider_cost field holds the raw amount, not a dollar value.
|
||||
# Use tracking_amount (the normalized value from resolve_tracking)
|
||||
# rather than raw stats.provider_cost to avoid unit mismatches.
|
||||
cost_microdollars = None
|
||||
if tracking_type == "cost_usd":
|
||||
cost_microdollars = usd_to_microdollars(stats.provider_cost)
|
||||
cost_microdollars = usd_to_microdollars(tracking_amount)
|
||||
|
||||
meta: dict[str, Any] = {
|
||||
"tracking_type": tracking_type,
|
||||
|
||||
Reference in New Issue
Block a user