fix(platform): address autogpt-reviewer should-fix items

- Remove dead _pending_log_tasks/schedule_cost_log/drain_pending_cost_logs
  from platform_cost.py (only cost_tracking.py and token_tracking.py have
  active task registries; drain comment updated to match)
- Replace vars(other) iteration in NodeExecutionStats.__iadd__ with
  type(other).model_fields to avoid any potential __pydantic_extra__ leakage
- Fix rate-override clear: onRateOverride(key, null) deletes the key so
  defaultRateFor() takes effect instead of pinning estimated cost to $0
- Type extract_openrouter_cost parameter as OpenAIChatCompletion
- Fix early-return guard in persist_and_record_usage: allow through when
  all token counts are 0 but cost_usd is provided (fully-cached responses)
- Add missing tests: LLM retry cost (only last attempt merged), zero-token
  copilot cost, Exa search + similar merge_stats coverage
This commit is contained in:
Zamil Majdy
2026-04-07 17:23:42 +07:00
parent 9b1175473b
commit db6b4444e0
10 changed files with 240 additions and 55 deletions

View File

@@ -208,3 +208,127 @@ class TestExaContentsCostTracking:
assert len(merged) == 1
assert merged[0].provider_cost == pytest.approx(0.0)
class TestExaSearchCostTracking:
    """ExaSearchBlock forwards the SDK's cost_dollars.total into merge_stats."""

    @staticmethod
    def _search_response(cost_dollars):
        """Build a fake Exa SDK search response carrying the given cost."""
        response = MagicMock()
        response.results = []
        response.context = None
        response.resolved_search_type = None
        response.cost_dollars = cost_dollars
        return response

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """A response that carries cost_dollars yields exactly one merged stat."""
        from backend.blocks.exa.helpers import CostDollars
        from backend.blocks.exa.search import ExaSearchBlock

        block = ExaSearchBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: captured.append(s)  # type: ignore[assignment]

        sdk_response = self._search_response(CostDollars(total=0.008))

        with patch("backend.blocks.exa.search.AsyncExa") as exa_cls:
            client = MagicMock()
            client.search = AsyncMock(return_value=sdk_response)
            exa_cls.return_value = client
            async for _output in block.run(
                block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(captured) == 1
        assert captured[0].provider_cost == pytest.approx(0.008)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """A response without cost_dollars never triggers merge_stats."""
        from backend.blocks.exa.search import ExaSearchBlock

        block = ExaSearchBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: captured.append(s)  # type: ignore[assignment]

        sdk_response = self._search_response(None)

        with patch("backend.blocks.exa.search.AsyncExa") as exa_cls:
            client = MagicMock()
            client.search = AsyncMock(return_value=sdk_response)
            exa_cls.return_value = client
            async for _output in block.run(
                block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert captured == []
class TestExaSimilarCostTracking:
    """ExaFindSimilarBlock forwards the SDK's cost_dollars.total into merge_stats."""

    @staticmethod
    def _similar_response(request_id, cost_dollars):
        """Build a fake Exa SDK find_similar response carrying the given cost."""
        response = MagicMock()
        response.results = []
        response.context = None
        response.request_id = request_id
        response.cost_dollars = cost_dollars
        return response

    @pytest.mark.asyncio
    async def test_cost_dollars_total_is_merged(self):
        """A response that carries cost_dollars yields exactly one merged stat."""
        from backend.blocks.exa.helpers import CostDollars
        from backend.blocks.exa.similar import ExaFindSimilarBlock

        block = ExaFindSimilarBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: captured.append(s)  # type: ignore[assignment]

        sdk_response = self._similar_response("req-1", CostDollars(total=0.015))

        with patch("backend.blocks.exa.similar.AsyncExa") as exa_cls:
            client = MagicMock()
            client.find_similar = AsyncMock(return_value=sdk_response)
            exa_cls.return_value = client
            async for _output in block.run(
                block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert len(captured) == 1
        assert captured[0].provider_cost == pytest.approx(0.015)

    @pytest.mark.asyncio
    async def test_no_cost_dollars_skips_merge(self):
        """A response without cost_dollars never triggers merge_stats."""
        from backend.blocks.exa.similar import ExaFindSimilarBlock

        block = ExaFindSimilarBlock()
        captured: list[NodeExecutionStats] = []
        block.merge_stats = lambda s: captured.append(s)  # type: ignore[assignment]

        sdk_response = self._similar_response("req-2", None)

        with patch("backend.blocks.exa.similar.AsyncExa") as exa_cls:
            client = MagicMock()
            client.find_similar = AsyncMock(return_value=sdk_response)
            exa_cls.return_value = client
            async for _output in block.run(
                block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT),  # type: ignore[arg-type]
                credentials=TEST_CREDENTIALS,
            ):
                pass

        assert captured == []

View File

@@ -13,6 +13,7 @@ import ollama
import openai
from anthropic.types import ToolParam
from groq import AsyncGroq
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
from pydantic import BaseModel, SecretStr
from backend.blocks._base import (
@@ -772,7 +773,7 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def extract_openrouter_cost(response) -> float | None:
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
OpenRouter returns the per-request USD cost in a response header. The

View File

@@ -199,6 +199,66 @@ class TestLLMStatsTracking:
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
assert block.execution_stats.llm_retry_count == 1
@pytest.mark.asyncio
async def test_retry_cost_uses_last_attempt_only(self):
    """provider_cost is only merged from the final successful attempt.

    Intermediate retry costs are intentionally dropped to avoid
    double-counting: the cost of failed attempts is captured in
    last_attempt_cost only when the loop eventually succeeds.
    """
    import backend.blocks.llm as llm

    block = llm.AIStructuredResponseGeneratorBlock()

    # Counts llm_call invocations so the stub can fail once, then succeed.
    call_count = 0

    async def mock_llm_call(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        if call_count == 1:
            # First attempt: fails validation, returns cost $0.01
            # (payload has a key that doesn't match expected_format below).
            return llm.LLMResponse(
                raw_response="",
                prompt=[],
                response='<json_output id="test123456">{"wrong": "key"}</json_output>',
                tool_calls=None,
                prompt_tokens=10,
                completion_tokens=5,
                reasoning=None,
                provider_cost=0.01,
            )
        # Second attempt: succeeds, returns cost $0.02
        return llm.LLMResponse(
            raw_response="",
            prompt=[],
            response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
            tool_calls=None,
            prompt_tokens=20,
            completion_tokens=10,
            reasoning=None,
            provider_cost=0.02,
        )

    block.llm_call = mock_llm_call  # type: ignore

    input_data = llm.AIStructuredResponseGeneratorBlock.Input(
        prompt="Test prompt",
        expected_format={"key1": "desc1", "key2": "desc2"},
        model=llm.DEFAULT_LLM_MODEL,
        credentials=llm.TEST_CREDENTIALS_INPUT,  # type: ignore
        retry=2,
    )

    # token_hex is pinned so the <json_output id=...> wrapper in the stubbed
    # responses above matches what the block generates and parses.
    with patch("secrets.token_hex", return_value="test123456"):
        async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
            pass

    # Only the final successful attempt's cost is merged
    assert block.execution_stats.provider_cost == pytest.approx(0.02)
    # Tokens from both attempts accumulate
    assert block.execution_stats.input_token_count == 30
    assert block.execution_stats.output_token_count == 15
@pytest.mark.asyncio
async def test_ai_text_summarizer_multiple_chunks(self):
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""

View File

@@ -93,12 +93,13 @@ async def persist_and_record_usage(
cache_read_tokens = max(0, cache_read_tokens)
cache_creation_tokens = max(0, cache_creation_tokens)
if (
no_tokens = (
prompt_tokens <= 0
and completion_tokens <= 0
and cache_read_tokens <= 0
and cache_creation_tokens <= 0
):
)
if no_tokens and cost_usd is None:
return 0
# total_tokens = prompt + completion. Cache tokens are tracked

View File

@@ -501,3 +501,39 @@ class TestPlatformCostLogging:
assert entry.metadata["cache_read_tokens"] == 5000
assert entry.metadata["cache_creation_tokens"] == 300
assert entry.metadata["source"] == "copilot"
@pytest.mark.asyncio
async def test_logs_cost_only_when_tokens_zero(self):
    """Zero prompt+completion tokens with cost_usd set still logs the entry."""
    log_spy = AsyncMock()
    # Minimal stand-in for the platform cost DB: only log_platform_cost is used.
    fake_cost_db = type(
        "FakePlatformCostDb", (), {"log_platform_cost": log_spy}
    )()

    with (
        patch(
            "backend.copilot.token_tracking.record_token_usage",
            new_callable=AsyncMock,
        ),
        patch(
            "backend.copilot.token_tracking.platform_cost_db",
            return_value=fake_cost_db,
        ),
    ):
        await persist_and_record_usage(
            session=None,
            user_id="user-cached",
            prompt_tokens=0,
            completion_tokens=0,
            cost_usd=0.005,
            model="claude-3-5-sonnet",
            provider="anthropic",
            log_prefix="[SDK]",
        )
        # Yield to the event loop so the fire-and-forget log task runs.
        await asyncio.sleep(0)

    # Guard: total_tokens == 0 but cost_usd is set — must still log
    log_spy.assert_awaited_once()
    entry = log_spy.call_args[0][0]
    assert entry.user_id == "user-cached"
    assert entry.tracking_type == "cost_usd"
    assert entry.cost_microdollars == 5000
    assert entry.input_tokens == 0
    assert entry.output_tokens == 0

View File

@@ -867,23 +867,12 @@ class NodeExecutionStats(BaseModel):
if not isinstance(other, NodeExecutionStats):
return NotImplemented
# vars() returns __dict__ for fields that are set (plus extras via
# model_config.extra='allow') — much cheaper than model_dump() which
# validates + serialises every field.
#
# Pydantic v2 stores all field values in __dict__, so vars() is
# equivalent to model_dump() for our declared fields. Internal keys
# (__pydantic_fields_set__, etc.) start with __ and are harmless —
# setattr on those would update the instance's private Pydantic state,
# but in practice they don't appear in __dict__ for field keys.
other_fields = vars(other)
self_fields = vars(self)
for key, value in other_fields.items():
for key in type(other).model_fields:
value = getattr(other, key)
if value is None:
# Never overwrite an existing value with None
continue
current = self_fields.get(key)
current = getattr(self, key, None)
if current is None:
# Field doesn't exist yet or is None, just set it
setattr(self, key, value)

View File

@@ -105,38 +105,6 @@ async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
)
# Hold strong references to in-flight log tasks to prevent GC.
# Tasks remove themselves on completion via add_done_callback.
# Concurrent DB inserts are bounded by _log_semaphore (50) to provide
# back-pressure under high load or DB slowness.
_pending_log_tasks: set["asyncio.Task[None]"] = set()
def schedule_cost_log(entry: PlatformCostEntry) -> None:
"""Schedule a fire-and-forget cost log insert.
Shared by cost_tracking and token_tracking so both modules drain
the same task set during shutdown.
"""
task = asyncio.create_task(log_platform_cost_safe(entry))
_pending_log_tasks.add(task)
task.add_done_callback(_pending_log_tasks.discard)
async def drain_pending_cost_logs() -> None:
"""Await all in-flight cost log tasks before shutdown.
Call this from ExecutionManager.cleanup() (or equivalent teardown hook)
to ensure no cost entries are silently dropped during a rolling deployment.
Tasks that were already completed are no-ops; only genuinely in-flight
tasks cause a real wait.
"""
pending = list(_pending_log_tasks)
if pending:
logger.info("Draining %d pending cost log task(s)…", len(pending))
await asyncio.gather(*pending, return_exceptions=True)
def _json_or_none(data: dict[str, Any] | None) -> str | None:
if data is None:
return None

View File

@@ -41,8 +41,8 @@ async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
Drains both the executor cost log tasks (_pending_log_tasks in this module,
used for block execution cost tracking via DatabaseManagerAsyncClient) and
the copilot cost log tasks (from platform_cost.schedule_cost_log, used by
token_tracking for copilot LLM turns).
the copilot cost log tasks (token_tracking._pending_log_tasks, used for
copilot LLM turns via platform_cost_db()).
Call this during graceful shutdown to flush pending INSERT tasks before
the process exits. Tasks that don't complete within `timeout` seconds are

View File

@@ -12,7 +12,7 @@ import { TrackingBadge } from "./TrackingBadge";
interface Props {
data: ProviderCostSummary[];
rateOverrides: Record<string, number>;
onRateOverride: (key: string, val: number) => void;
onRateOverride: (key: string, val: number | null) => void;
}
function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
@@ -79,7 +79,7 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
const val = parseFloat(e.target.value);
if (!isNaN(val)) onRateOverride(key, val);
else if (e.target.value === "")
onRateOverride(key, 0);
onRateOverride(key, null);
}}
/>
<span

View File

@@ -95,8 +95,14 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
});
}
function handleRateOverride(key: string, val: number) {
setRateOverrides((prev) => ({ ...prev, [key]: val }));
function handleRateOverride(key: string, val: number | null) {
  // null clears the override so defaultRateFor() applies again;
  // a number pins the rate for this key.
  setRateOverrides((prev) => {
    const next = { ...prev };
    if (val === null) {
      delete next[key];
    } else {
      next[key] = val;
    }
    return next;
  });
}
const totalEstimatedCost =