fix(platform): address autogpt-reviewer blockers and should-fix items

- Fix LLM retry double-counting: track tokens per attempt but only merge
  provider_cost on the successful attempt, not across all retries
- Add drain_pending_cost_logs() to platform_cost.py; update cost_tracking
  to drain both executor and copilot task sets on shutdown
- Remove prohibited dark: Tailwind classes from PlatformCostContent error
  div, replace with Alert component (design system error variant)
- Add block-level cost tracking tests for: JinaEmbeddingBlock (with/without
  usage), UnrealTextToSpeechBlock (character count), GoogleMapsSearchBlock
  (place count), AddLeadToCampaignBlock (lead count)
- Add __iadd__ edge case tests: provider_cost_type first-write-to-None and
  None does not overwrite existing value
- Rename metadata key provider_cost_usd to provider_cost_raw (value unit
  varies by tracking type; only cost_usd uses USD)
- Add test verifying per_run providers have no provider_cost_raw in metadata
This commit is contained in:
Zamil Majdy
2026-04-07 15:05:44 +07:00
parent f87bbd5966
commit 16d696edcc
7 changed files with 422 additions and 12 deletions

View File

@@ -342,3 +342,296 @@ class TestSearchOrganizationsBlockCostTracking:
assert len(accumulated) == 1
assert accumulated[0].provider_cost == 0.0
assert accumulated[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# JinaEmbeddingBlock — token count from usage.total_tokens
# ---------------------------------------------------------------------------
class TestJinaEmbeddingBlockCostTracking:
    """Cost tracking for JinaEmbeddingBlock: token count from usage.total_tokens."""

    @pytest.mark.asyncio
    async def test_merge_stats_called_with_token_count(self):
        """The usage.total_tokens value from the API is merged into node stats."""
        from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
        from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
        from backend.blocks.jina.embeddings import JinaEmbeddingBlock

        embedding_block = JinaEmbeddingBlock()
        payload = {
            "data": [{"embedding": [0.1, 0.2, 0.3]}],
            "usage": {"total_tokens": 42},
        }
        fake_response = MagicMock()
        fake_response.json.return_value = payload
        captured: list[NodeExecutionStats] = []

        post_patch = patch(
            "backend.blocks.jina.embeddings.Requests.post",
            new_callable=AsyncMock,
            return_value=fake_response,
        )
        stats_patch = patch.object(
            embedding_block, "merge_stats", side_effect=captured.append
        )
        with post_patch, stats_patch:
            inputs = JinaEmbeddingBlock.Input(
                texts=["hello world"],
                credentials=JINA_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in embedding_block.run(inputs, credentials=JINA_CREDS):
                pass

        assert len(captured) == 1
        assert captured[0].input_token_count == 42

    @pytest.mark.asyncio
    async def test_no_merge_stats_when_usage_absent(self):
        """If the API response has no usage field, merge_stats is never invoked."""
        from backend.blocks.jina._auth import TEST_CREDENTIALS as JINA_CREDS
        from backend.blocks.jina._auth import TEST_CREDENTIALS_INPUT as JINA_CREDS_INPUT
        from backend.blocks.jina.embeddings import JinaEmbeddingBlock

        embedding_block = JinaEmbeddingBlock()
        # Deliberately omit the "usage" key from the simulated API payload.
        fake_response = MagicMock()
        fake_response.json.return_value = {
            "data": [{"embedding": [0.1, 0.2, 0.3]}],
        }
        captured: list[NodeExecutionStats] = []

        post_patch = patch(
            "backend.blocks.jina.embeddings.Requests.post",
            new_callable=AsyncMock,
            return_value=fake_response,
        )
        stats_patch = patch.object(
            embedding_block, "merge_stats", side_effect=captured.append
        )
        with post_patch, stats_patch:
            inputs = JinaEmbeddingBlock.Input(
                texts=["hello"],
                credentials=JINA_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in embedding_block.run(inputs, credentials=JINA_CREDS):
                pass

        assert captured == []
# ---------------------------------------------------------------------------
# UnrealTextToSpeechBlock — character count from input text length
# ---------------------------------------------------------------------------
class TestUnrealTextToSpeechBlockCostTracking:
    """Cost tracking for UnrealTextToSpeechBlock: character count of input text."""

    @pytest.mark.asyncio
    async def test_merge_stats_called_with_character_count(self):
        """merge_stats receives provider_cost == len(text) with type 'characters'."""
        from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
        from backend.blocks.text_to_speech_block import (
            TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
        )
        from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock

        tts_block = UnrealTextToSpeechBlock()
        spoken_text = "Hello, world!"
        api_patch = patch.object(
            UnrealTextToSpeechBlock,
            "call_unreal_speech_api",
            new_callable=AsyncMock,
            return_value={"OutputUri": "https://example.com/audio.mp3"},
        )
        with api_patch, patch.object(tts_block, "merge_stats") as merge_mock:
            inputs = UnrealTextToSpeechBlock.Input(
                text=spoken_text,
                credentials=TTS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in tts_block.run(inputs, credentials=TTS_CREDS):
                pass

        merge_mock.assert_called_once()
        recorded = merge_mock.call_args[0][0]
        assert recorded.provider_cost == float(len(spoken_text))
        assert recorded.provider_cost_type == "characters"

    @pytest.mark.asyncio
    async def test_empty_text_gives_zero_characters(self):
        """Empty input text yields provider_cost == 0.0."""
        from backend.blocks.text_to_speech_block import TEST_CREDENTIALS as TTS_CREDS
        from backend.blocks.text_to_speech_block import (
            TEST_CREDENTIALS_INPUT as TTS_CREDS_INPUT,
        )
        from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock

        tts_block = UnrealTextToSpeechBlock()
        api_patch = patch.object(
            UnrealTextToSpeechBlock,
            "call_unreal_speech_api",
            new_callable=AsyncMock,
            return_value={"OutputUri": "https://example.com/audio.mp3"},
        )
        with api_patch, patch.object(tts_block, "merge_stats") as merge_mock:
            inputs = UnrealTextToSpeechBlock.Input(
                text="",
                credentials=TTS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in tts_block.run(inputs, credentials=TTS_CREDS):
                pass

        merge_mock.assert_called_once()
        recorded = merge_mock.call_args[0][0]
        assert recorded.provider_cost == 0.0
        assert recorded.provider_cost_type == "characters"
# ---------------------------------------------------------------------------
# GoogleMapsSearchBlock — item count from search_places results
# ---------------------------------------------------------------------------
class TestGoogleMapsSearchBlockCostTracking:
    """Cost tracking for GoogleMapsSearchBlock: item count from search results."""

    @pytest.mark.asyncio
    async def test_merge_stats_called_with_place_count(self):
        """provider_cost matches the number of places returned, type 'items'."""
        from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
        from backend.blocks.google_maps import (
            TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
        )
        from backend.blocks.google_maps import GoogleMapsSearchBlock

        maps_block = GoogleMapsSearchBlock()
        place_fixtures = [
            {"name": f"Place{idx}", "address": f"Addr{idx}"} for idx in range(4)
        ]
        captured: list[NodeExecutionStats] = []

        search_patch = patch.object(
            GoogleMapsSearchBlock,
            "search_places",
            return_value=place_fixtures,
        )
        stats_patch = patch.object(
            maps_block, "merge_stats", side_effect=captured.append
        )
        with search_patch, stats_patch:
            inputs = GoogleMapsSearchBlock.Input(
                query="coffee shops",
                credentials=MAPS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in maps_block.run(inputs, credentials=MAPS_CREDS):
                pass

        assert len(captured) == 1
        assert captured[0].provider_cost == 4.0
        assert captured[0].provider_cost_type == "items"

    @pytest.mark.asyncio
    async def test_empty_results_tracks_zero(self):
        """An empty result set is still tracked, with provider_cost == 0.0."""
        from backend.blocks.google_maps import TEST_CREDENTIALS as MAPS_CREDS
        from backend.blocks.google_maps import (
            TEST_CREDENTIALS_INPUT as MAPS_CREDS_INPUT,
        )
        from backend.blocks.google_maps import GoogleMapsSearchBlock

        maps_block = GoogleMapsSearchBlock()
        captured: list[NodeExecutionStats] = []

        search_patch = patch.object(
            GoogleMapsSearchBlock,
            "search_places",
            return_value=[],
        )
        stats_patch = patch.object(
            maps_block, "merge_stats", side_effect=captured.append
        )
        with search_patch, stats_patch:
            inputs = GoogleMapsSearchBlock.Input(
                query="nothing here",
                credentials=MAPS_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in maps_block.run(inputs, credentials=MAPS_CREDS):
                pass

        assert len(captured) == 1
        assert captured[0].provider_cost == 0.0
        assert captured[0].provider_cost_type == "items"
# ---------------------------------------------------------------------------
# AddLeadToCampaignBlock (SmartLead) — item count from lead_list length
# ---------------------------------------------------------------------------
class TestSmartLeadAddLeadsBlockCostTracking:
    """Cost tracking for AddLeadToCampaignBlock: item count from lead_list length."""

    @pytest.mark.asyncio
    async def test_merge_stats_called_with_lead_count(self):
        """provider_cost equals the number of uploaded leads, type 'items'."""
        from backend.blocks.smartlead._auth import TEST_CREDENTIALS as SL_CREDS
        from backend.blocks.smartlead._auth import (
            TEST_CREDENTIALS_INPUT as SL_CREDS_INPUT,
        )
        from backend.blocks.smartlead.campaign import AddLeadToCampaignBlock
        from backend.blocks.smartlead.models import (
            AddLeadsToCampaignResponse,
            LeadInput,
        )

        campaign_block = AddLeadToCampaignBlock()
        leads = [
            LeadInput(first_name="Alice", last_name="A", email="alice@example.com"),
            LeadInput(first_name="Bob", last_name="B", email="bob@example.com"),
        ]
        # A successful upload acknowledging both leads, with no rejections.
        upload_response = AddLeadsToCampaignResponse(
            ok=True,
            upload_count=2,
            total_leads=2,
            block_count=0,
            duplicate_count=0,
            invalid_email_count=0,
            invalid_emails=[],
            already_added_to_campaign=0,
            unsubscribed_leads=[],
            is_lead_limit_exhausted=False,
            lead_import_stopped_count=0,
            bounce_count=0,
        )
        captured: list[NodeExecutionStats] = []

        upload_patch = patch.object(
            AddLeadToCampaignBlock,
            "add_leads_to_campaign",
            new_callable=AsyncMock,
            return_value=upload_response,
        )
        stats_patch = patch.object(
            campaign_block, "merge_stats", side_effect=captured.append
        )
        with upload_patch, stats_patch:
            inputs = AddLeadToCampaignBlock.Input(
                campaign_id=123,
                lead_list=leads,
                credentials=SL_CREDS_INPUT,  # type: ignore[arg-type]
            )
            async for _ in campaign_block.run(inputs, credentials=SL_CREDS):
                pass

        assert len(captured) == 1
        assert captured[0].provider_cost == 2.0
        assert captured[0].provider_cost_type == "items"

View File

@@ -1432,6 +1432,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
error_feedback_message = ""
llm_model = input_data.model
last_attempt_cost: float | None = None
for retry_count in range(input_data.retry):
logger.debug(f"LLM request: {prompt}")
@@ -1449,13 +1450,15 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
max_tokens=input_data.max_tokens,
)
response_text = llm_response.response
cost_stats = NodeExecutionStats(
# Merge token counts for every attempt (each call costs tokens).
# provider_cost (actual USD) is tracked separately and only merged
# on success to avoid double-counting across retries.
token_stats = NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
if llm_response.provider_cost is not None:
cost_stats.provider_cost = llm_response.provider_cost
self.merge_stats(cost_stats)
self.merge_stats(token_stats)
last_attempt_cost = llm_response.provider_cost
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:
@@ -1524,6 +1527,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", response_obj
@@ -1544,6 +1548,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
provider_cost=last_attempt_cost,
)
)
yield "response", {"response": response_text}

View File

@@ -209,3 +209,41 @@ class TestNodeExecutionStatsIadd:
a = NodeExecutionStats()
result = a.__iadd__("not a stats") # type: ignore[arg-type]
assert result is NotImplemented
def test_error_none_does_not_clear_existing_error(self):
    """Merging a stats object whose error is None keeps the existing error."""
    merged = NodeExecutionStats(error="existing error")
    merged += NodeExecutionStats(error=None)
    assert merged.error == "existing error"
def test_provider_cost_none_does_not_clear_existing_cost(self):
    """A None provider_cost on the right-hand side must not erase a set value."""
    merged = NodeExecutionStats(provider_cost=0.05)
    merged += NodeExecutionStats(provider_cost=None)
    assert merged.provider_cost == 0.05
def test_provider_cost_accumulates_when_both_set(self):
    """Two concrete provider costs sum during an in-place merge."""
    merged = NodeExecutionStats(provider_cost=0.01)
    merged += NodeExecutionStats(provider_cost=0.02)
    # Float-tolerant comparison; provider_cost may be Optional, so default to 0.
    total = merged.provider_cost or 0
    assert abs(total - 0.03) < 1e-9
def test_provider_cost_first_write_from_none(self):
    """Merging into a default (None) provider_cost adopts the incoming value."""
    merged = NodeExecutionStats()
    merged += NodeExecutionStats(provider_cost=0.05)
    assert merged.provider_cost == 0.05
def test_provider_cost_type_first_write_from_none(self):
    """Merging a provider_cost_type into a stats whose type is None sets it."""
    merged = NodeExecutionStats()
    merged += NodeExecutionStats(provider_cost_type="characters")
    assert merged.provider_cost_type == "characters"
def test_provider_cost_type_none_does_not_overwrite(self):
    """A default (None) provider_cost_type must not clear an existing value."""
    merged = NodeExecutionStats(provider_cost_type="tokens")
    merged += NodeExecutionStats()
    assert merged.provider_cost_type == "tokens"

View File

@@ -113,6 +113,20 @@ def schedule_cost_log(entry: PlatformCostEntry) -> None:
task.add_done_callback(_pending_log_tasks.discard)
async def drain_pending_cost_logs() -> None:
    """Await every in-flight cost log task before shutdown.

    Invoke this from ExecutionManager.cleanup() (or an equivalent teardown
    hook) so that no cost entries are silently dropped during a rolling
    deployment. Already-completed tasks are no-ops; only genuinely in-flight
    tasks incur a real wait.
    """
    inflight = list(_pending_log_tasks)
    if not inflight:
        return
    logger.info("Draining %d pending cost log task(s)…", len(inflight))
    # return_exceptions=True: a failed log task must not abort the drain.
    await asyncio.gather(*inflight, return_exceptions=True)
def _json_or_none(data: dict[str, Any] | None) -> str | None:
if data is None:
return None
@@ -310,6 +324,8 @@ async def get_platform_cost_logs(
page: int = 1,
page_size: int = 50,
) -> tuple[list[CostLogRow], int]:
if start is None:
start = datetime.now(tz=timezone.utc) - timedelta(days=DEFAULT_DASHBOARD_DAYS)
where_sql, params = _build_where(start, end, provider, user_id, "p")
offset = (page - 1) * page_size

View File

@@ -36,6 +36,36 @@ _WALLTIME_BILLED_PROVIDERS = frozenset(
_pending_log_tasks: set[asyncio.Task] = set()
async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
    """Await all in-flight cost log tasks with a timeout.

    Drains both the executor cost log tasks (_pending_log_tasks in this module,
    used for block execution cost tracking via DatabaseManagerAsyncClient) and
    the copilot cost log tasks (from platform_cost.schedule_cost_log, used by
    token_tracking for copilot LLM turns).

    Call this during graceful shutdown to flush pending INSERT tasks before
    the process exits. Tasks that don't complete within `timeout` seconds are
    abandoned and their failures are already logged by _safe_log.
    """
    all_pending = list(_pending_log_tasks)
    if all_pending:
        logger.info("Draining %d executor cost log task(s)", len(all_pending))
        _, still_pending = await asyncio.wait(all_pending, timeout=timeout)
        if still_pending:
            logger.warning(
                "%d executor cost log task(s) did not complete within %.1fs",
                len(still_pending),
                timeout,
            )
    # Also drain copilot cost log tasks (platform_cost._pending_log_tasks).
    # Bound this wait as well: the copilot drain gathers its tasks with no
    # timeout of its own, so an unbounded await here could hang graceful
    # shutdown indefinitely — contradicting the `timeout` contract above.
    from backend.data.platform_cost import (  # noqa: PLC0415
        drain_pending_cost_logs as _drain_copilot,
    )

    try:
        await asyncio.wait_for(_drain_copilot(), timeout=timeout)
    except asyncio.TimeoutError:
        logger.warning(
            "Copilot cost log drain did not complete within %.1fs", timeout
        )
def _schedule_log(
db_client: "DatabaseManagerAsyncClient", entry: PlatformCostEntry
) -> None:
@@ -175,7 +205,9 @@ async def log_system_credential_cost(
if credit_cost:
meta["credit_cost"] = credit_cost
if stats.provider_cost is not None:
meta["provider_cost_usd"] = stats.provider_cost
# Use 'provider_cost_raw' — the value's unit varies by tracking
# type (USD for cost_usd, count for items/characters/per_run, etc.)
meta["provider_cost_raw"] = stats.provider_cost
_schedule_log(
db_client,

View File

@@ -311,7 +311,7 @@ class TestLogSystemCredentialCost:
assert entry.cost_microdollars == 1500
assert entry.tracking_type == "cost_usd"
assert entry.metadata["tracking_type"] == "cost_usd"
assert entry.metadata["provider_cost_usd"] == 0.0015
assert entry.metadata["provider_cost_raw"] == 0.0015
@pytest.mark.asyncio
async def test_model_name_enum_converted_to_str(self):
@@ -420,6 +420,34 @@ class TestLogSystemCredentialCost:
entry = db_client.log_platform_cost.call_args[0][0]
assert entry.cost_microdollars == 1500
@pytest.mark.asyncio
async def test_per_run_metadata_has_no_provider_cost_raw(self):
    """per_run providers (google_maps etc) must not emit provider_cost_raw:
    stats.provider_cost is None for them, so the metadata key stays absent."""
    db_client = _make_db_client()
    sys_cred_patch = patch(
        "backend.executor.cost_tracking.is_system_credential", return_value=True
    )
    usage_cost_patch = patch(
        "backend.executor.cost_tracking.block_usage_cost",
        return_value=(0, None),
    )
    with sys_cred_patch, usage_cost_patch:
        node_exec = _make_node_exec(
            inputs={
                "credentials": {"id": "sys-cred", "provider": "google_maps"},
            }
        )
        # No provider_cost on the stats object — the per-run path.
        await log_system_credential_cost(
            node_exec, _make_block(), NodeExecutionStats(), db_client
        )
        # Yield once so the scheduled log task runs.
        await asyncio.sleep(0)

    entry = db_client.log_platform_cost.call_args[0][0]
    assert entry.tracking_type == "per_run"
    assert "provider_cost_raw" not in (entry.metadata or {})
# ---------------------------------------------------------------------------
# merge_stats accumulation

View File

@@ -1,5 +1,6 @@
"use client";
import { Alert, AlertDescription } from "@/components/molecules/Alert/Alert";
import { formatMicrodollars } from "../helpers";
import { SummaryCard } from "./SummaryCard";
import { ProviderTable } from "./ProviderTable";
@@ -127,12 +128,9 @@ function PlatformCostContent({ searchParams }: Props) {
</div>
{error && (
<div
role="alert"
className="rounded-lg border border-red-300 bg-red-50 p-4 text-sm text-red-700 dark:border-red-800 dark:bg-red-950/20 dark:text-red-400"
>
{error}
</div>
<Alert variant="error">
<AlertDescription>{error}</AlertDescription>
</Alert>
)}
{loading ? (