mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(platform): address autogpt-reviewer should-fix items
- Remove dead _pending_log_tasks/schedule_cost_log/drain_pending_cost_logs from platform_cost.py (only cost_tracking.py and token_tracking.py have active task registries; drain comment updated to match) - Replace vars(other) iteration in NodeExecutionStats.__iadd__ with type(other).model_fields to avoid any potential __pydantic_extra__ leakage - Fix rate-override clear: onRateOverride(key, null) deletes the key so defaultRateFor() takes effect instead of pinning estimated cost to $0 - Type extract_openrouter_cost parameter as OpenAIChatCompletion - Fix early-return guard in persist_and_record_usage: allow through when all token counts are 0 but cost_usd is provided (fully-cached responses) - Add missing tests: LLM retry cost (only last attempt merged), zero-token copilot cost, Exa search + similar merge_stats coverage
This commit is contained in:
@@ -208,3 +208,127 @@ class TestExaContentsCostTracking:
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.0)
|
||||
|
||||
|
||||
class TestExaSearchCostTracking:
|
||||
"""ExaSearchBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
from backend.blocks.exa.search import ExaSearchBlock
|
||||
|
||||
block = ExaSearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.resolved_search_type = None
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.008)
|
||||
|
||||
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.008)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.search import ExaSearchBlock
|
||||
|
||||
block = ExaSearchBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.resolved_search_type = None
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.search.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.search = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(query="test query", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
|
||||
class TestExaSimilarCostTracking:
|
||||
"""ExaFindSimilarBlock merges cost_dollars.total as provider_cost."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_dollars_total_is_merged(self):
|
||||
"""When the SDK response includes cost_dollars, its total is merged."""
|
||||
from backend.blocks.exa.helpers import CostDollars
|
||||
from backend.blocks.exa.similar import ExaFindSimilarBlock
|
||||
|
||||
block = ExaFindSimilarBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.request_id = "req-1"
|
||||
mock_sdk_response.cost_dollars = CostDollars(total=0.015)
|
||||
|
||||
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 1
|
||||
assert merged[0].provider_cost == pytest.approx(0.015)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_cost_dollars_skips_merge(self):
|
||||
"""When cost_dollars is absent, merge_stats is not called."""
|
||||
from backend.blocks.exa.similar import ExaFindSimilarBlock
|
||||
|
||||
block = ExaFindSimilarBlock()
|
||||
merged: list[NodeExecutionStats] = []
|
||||
block.merge_stats = lambda s: merged.append(s) # type: ignore[assignment]
|
||||
|
||||
mock_sdk_response = MagicMock()
|
||||
mock_sdk_response.results = []
|
||||
mock_sdk_response.context = None
|
||||
mock_sdk_response.request_id = "req-2"
|
||||
mock_sdk_response.cost_dollars = None
|
||||
|
||||
with patch("backend.blocks.exa.similar.AsyncExa") as mock_exa_cls:
|
||||
mock_exa = MagicMock()
|
||||
mock_exa.find_similar = AsyncMock(return_value=mock_sdk_response)
|
||||
mock_exa_cls.return_value = mock_exa
|
||||
|
||||
async for _ in block.run(
|
||||
block.Input(url="https://example.com", credentials=TEST_CREDENTIALS_INPUT), # type: ignore[arg-type]
|
||||
credentials=TEST_CREDENTIALS,
|
||||
):
|
||||
pass
|
||||
|
||||
assert len(merged) == 0
|
||||
|
||||
@@ -13,6 +13,7 @@ import ollama
|
||||
import openai
|
||||
from anthropic.types import ToolParam
|
||||
from groq import AsyncGroq
|
||||
from openai.types.chat import ChatCompletion as OpenAIChatCompletion
|
||||
from pydantic import BaseModel, SecretStr
|
||||
|
||||
from backend.blocks._base import (
|
||||
@@ -772,7 +773,7 @@ def convert_openai_tool_fmt_to_anthropic(
|
||||
return anthropic_tools
|
||||
|
||||
|
||||
def extract_openrouter_cost(response) -> float | None:
|
||||
def extract_openrouter_cost(response: OpenAIChatCompletion) -> float | None:
|
||||
"""Extract OpenRouter's `x-total-cost` header from an OpenAI SDK response.
|
||||
|
||||
OpenRouter returns the per-request USD cost in a response header. The
|
||||
|
||||
@@ -199,6 +199,66 @@ class TestLLMStatsTracking:
|
||||
assert block.execution_stats.llm_call_count == 2 # retry_count + 1 = 1 + 1 = 2
|
||||
assert block.execution_stats.llm_retry_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_cost_uses_last_attempt_only(self):
|
||||
"""provider_cost is only merged from the final successful attempt.
|
||||
|
||||
Intermediate retry costs are intentionally dropped to avoid
|
||||
double-counting: the cost of failed attempts is captured in
|
||||
last_attempt_cost only when the loop eventually succeeds.
|
||||
"""
|
||||
import backend.blocks.llm as llm
|
||||
|
||||
block = llm.AIStructuredResponseGeneratorBlock()
|
||||
call_count = 0
|
||||
|
||||
async def mock_llm_call(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# First attempt: fails validation, returns cost $0.01
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"wrong": "key"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
reasoning=None,
|
||||
provider_cost=0.01,
|
||||
)
|
||||
# Second attempt: succeeds, returns cost $0.02
|
||||
return llm.LLMResponse(
|
||||
raw_response="",
|
||||
prompt=[],
|
||||
response='<json_output id="test123456">{"key1": "value1", "key2": "value2"}</json_output>',
|
||||
tool_calls=None,
|
||||
prompt_tokens=20,
|
||||
completion_tokens=10,
|
||||
reasoning=None,
|
||||
provider_cost=0.02,
|
||||
)
|
||||
|
||||
block.llm_call = mock_llm_call # type: ignore
|
||||
|
||||
input_data = llm.AIStructuredResponseGeneratorBlock.Input(
|
||||
prompt="Test prompt",
|
||||
expected_format={"key1": "desc1", "key2": "desc2"},
|
||||
model=llm.DEFAULT_LLM_MODEL,
|
||||
credentials=llm.TEST_CREDENTIALS_INPUT, # type: ignore
|
||||
retry=2,
|
||||
)
|
||||
|
||||
with patch("secrets.token_hex", return_value="test123456"):
|
||||
async for _ in block.run(input_data, credentials=llm.TEST_CREDENTIALS):
|
||||
pass
|
||||
|
||||
# Only the final successful attempt's cost is merged
|
||||
assert block.execution_stats.provider_cost == pytest.approx(0.02)
|
||||
# Tokens from both attempts accumulate
|
||||
assert block.execution_stats.input_token_count == 30
|
||||
assert block.execution_stats.output_token_count == 15
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ai_text_summarizer_multiple_chunks(self):
|
||||
"""Test that AITextSummarizerBlock correctly accumulates stats across multiple chunks."""
|
||||
|
||||
@@ -93,12 +93,13 @@ async def persist_and_record_usage(
|
||||
cache_read_tokens = max(0, cache_read_tokens)
|
||||
cache_creation_tokens = max(0, cache_creation_tokens)
|
||||
|
||||
if (
|
||||
no_tokens = (
|
||||
prompt_tokens <= 0
|
||||
and completion_tokens <= 0
|
||||
and cache_read_tokens <= 0
|
||||
and cache_creation_tokens <= 0
|
||||
):
|
||||
)
|
||||
if no_tokens and cost_usd is None:
|
||||
return 0
|
||||
|
||||
# total_tokens = prompt + completion. Cache tokens are tracked
|
||||
|
||||
@@ -501,3 +501,39 @@ class TestPlatformCostLogging:
|
||||
assert entry.metadata["cache_read_tokens"] == 5000
|
||||
assert entry.metadata["cache_creation_tokens"] == 300
|
||||
assert entry.metadata["source"] == "copilot"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_logs_cost_only_when_tokens_zero(self):
|
||||
"""Zero prompt+completion tokens with cost_usd set still logs the entry."""
|
||||
mock_log = AsyncMock()
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.token_tracking.record_token_usage",
|
||||
new_callable=AsyncMock,
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.token_tracking.platform_cost_db",
|
||||
return_value=type(
|
||||
"FakePlatformCostDb", (), {"log_platform_cost": mock_log}
|
||||
)(),
|
||||
),
|
||||
):
|
||||
await persist_and_record_usage(
|
||||
session=None,
|
||||
user_id="user-cached",
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
cost_usd=0.005,
|
||||
model="claude-3-5-sonnet",
|
||||
provider="anthropic",
|
||||
log_prefix="[SDK]",
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# Guard: total_tokens == 0 but cost_usd is set — must still log
|
||||
mock_log.assert_awaited_once()
|
||||
entry = mock_log.call_args[0][0]
|
||||
assert entry.user_id == "user-cached"
|
||||
assert entry.tracking_type == "cost_usd"
|
||||
assert entry.cost_microdollars == 5000
|
||||
assert entry.input_tokens == 0
|
||||
assert entry.output_tokens == 0
|
||||
|
||||
@@ -867,23 +867,12 @@ class NodeExecutionStats(BaseModel):
|
||||
if not isinstance(other, NodeExecutionStats):
|
||||
return NotImplemented
|
||||
|
||||
# vars() returns __dict__ for fields that are set (plus extras via
|
||||
# model_config.extra='allow') — much cheaper than model_dump() which
|
||||
# validates + serialises every field.
|
||||
#
|
||||
# Pydantic v2 stores all field values in __dict__, so vars() is
|
||||
# equivalent to model_dump() for our declared fields. Internal keys
|
||||
# (__pydantic_fields_set__, etc.) start with __ and are harmless —
|
||||
# setattr on those would update the instance's private Pydantic state,
|
||||
# but in practice they don't appear in __dict__ for field keys.
|
||||
other_fields = vars(other)
|
||||
self_fields = vars(self)
|
||||
|
||||
for key, value in other_fields.items():
|
||||
for key in type(other).model_fields:
|
||||
value = getattr(other, key)
|
||||
if value is None:
|
||||
# Never overwrite an existing value with None
|
||||
continue
|
||||
current = self_fields.get(key)
|
||||
current = getattr(self, key, None)
|
||||
if current is None:
|
||||
# Field doesn't exist yet or is None, just set it
|
||||
setattr(self, key, value)
|
||||
|
||||
@@ -105,38 +105,6 @@ async def log_platform_cost_safe(entry: PlatformCostEntry) -> None:
|
||||
)
|
||||
|
||||
|
||||
# Hold strong references to in-flight log tasks to prevent GC.
|
||||
# Tasks remove themselves on completion via add_done_callback.
|
||||
# Concurrent DB inserts are bounded by _log_semaphore (50) to provide
|
||||
# back-pressure under high load or DB slowness.
|
||||
_pending_log_tasks: set["asyncio.Task[None]"] = set()
|
||||
|
||||
|
||||
def schedule_cost_log(entry: PlatformCostEntry) -> None:
|
||||
"""Schedule a fire-and-forget cost log insert.
|
||||
|
||||
Shared by cost_tracking and token_tracking so both modules drain
|
||||
the same task set during shutdown.
|
||||
"""
|
||||
task = asyncio.create_task(log_platform_cost_safe(entry))
|
||||
_pending_log_tasks.add(task)
|
||||
task.add_done_callback(_pending_log_tasks.discard)
|
||||
|
||||
|
||||
async def drain_pending_cost_logs() -> None:
|
||||
"""Await all in-flight cost log tasks before shutdown.
|
||||
|
||||
Call this from ExecutionManager.cleanup() (or equivalent teardown hook)
|
||||
to ensure no cost entries are silently dropped during a rolling deployment.
|
||||
Tasks that were already completed are no-ops; only genuinely in-flight
|
||||
tasks cause a real wait.
|
||||
"""
|
||||
pending = list(_pending_log_tasks)
|
||||
if pending:
|
||||
logger.info("Draining %d pending cost log task(s)…", len(pending))
|
||||
await asyncio.gather(*pending, return_exceptions=True)
|
||||
|
||||
|
||||
def _json_or_none(data: dict[str, Any] | None) -> str | None:
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
@@ -41,8 +41,8 @@ async def drain_pending_cost_logs(timeout: float = 5.0) -> None:
|
||||
|
||||
Drains both the executor cost log tasks (_pending_log_tasks in this module,
|
||||
used for block execution cost tracking via DatabaseManagerAsyncClient) and
|
||||
the copilot cost log tasks (from platform_cost.schedule_cost_log, used by
|
||||
token_tracking for copilot LLM turns).
|
||||
the copilot cost log tasks (token_tracking._pending_log_tasks, used for
|
||||
copilot LLM turns via platform_cost_db()).
|
||||
|
||||
Call this during graceful shutdown to flush pending INSERT tasks before
|
||||
the process exits. Tasks that don't complete within `timeout` seconds are
|
||||
|
||||
@@ -12,7 +12,7 @@ import { TrackingBadge } from "./TrackingBadge";
|
||||
interface Props {
|
||||
data: ProviderCostSummary[];
|
||||
rateOverrides: Record<string, number>;
|
||||
onRateOverride: (key: string, val: number) => void;
|
||||
onRateOverride: (key: string, val: number | null) => void;
|
||||
}
|
||||
|
||||
function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
|
||||
@@ -79,7 +79,7 @@ function ProviderTable({ data, rateOverrides, onRateOverride }: Props) {
|
||||
const val = parseFloat(e.target.value);
|
||||
if (!isNaN(val)) onRateOverride(key, val);
|
||||
else if (e.target.value === "")
|
||||
onRateOverride(key, 0);
|
||||
onRateOverride(key, null);
|
||||
}}
|
||||
/>
|
||||
<span
|
||||
|
||||
@@ -95,8 +95,14 @@ export function usePlatformCostContent(searchParams: InitialSearchParams) {
|
||||
});
|
||||
}
|
||||
|
||||
function handleRateOverride(key: string, val: number) {
|
||||
setRateOverrides((prev) => ({ ...prev, [key]: val }));
|
||||
function handleRateOverride(key: string, val: number | null) {
|
||||
setRateOverrides((prev) => {
|
||||
if (val === null) {
|
||||
const { [key]: _, ...rest } = prev;
|
||||
return rest;
|
||||
}
|
||||
return { ...prev, [key]: val };
|
||||
});
|
||||
}
|
||||
|
||||
const totalEstimatedCost =
|
||||
|
||||
Reference in New Issue
Block a user