refactor(platform): use provider_cost_type Literal instead of output_size misuse

Blocks previously called merge_stats(NodeExecutionStats(output_size=...))
to signal "per-request" billing or "N items returned", but `output_size`
is semantically the output payload byte count and is always overridden
by the executor wrapper (manager.py:440 = len(json.dumps(output_data))).
Those calls were silently dead code.

Changes:
- Add a module-level ProviderCostType Literal type alias with the
  canonical set of tracking types: cost_usd, tokens, characters,
  sandbox_seconds, walltime_seconds, per_run, items.
- Add provider_cost_type field to NodeExecutionStats so blocks can
  declare their billing model explicitly instead of resolve_tracking
  guessing from provider name.
- resolve_tracking honors provider_cost_type first, falling back to
  heuristics only when not set.
- Remove 26 dead merge_stats(output_size=1) calls across 15 blocks.
- Replace 5 merge_stats(output_size=len(X)) calls with explicit
  provider_cost+provider_cost_type (items/characters) so the count
  is preserved through the wrapper's output_size override.
- Clean up unused NodeExecutionStats imports in 14 files.
- Add tests for block-declared provider_cost_type pathway.
This commit is contained in:
Zamil Majdy
2026-04-05 14:56:44 +02:00
parent 78b95f8a76
commit 44714f1b25
22 changed files with 88 additions and 60 deletions

View File

@@ -18,7 +18,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -359,7 +358,6 @@ class AIShortformVideoCreatorBlock(Block):
execution_context=execution_context,
return_format="for_block_output",
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "video_url", stored_url
@@ -567,7 +565,6 @@ class AIAdMakerVideoCreatorBlock(Block):
execution_context=execution_context,
return_format="for_block_output",
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "video_url", stored_url
@@ -763,5 +760,4 @@ class AIScreenshotToVideoAdBlock(Block):
execution_context=execution_context,
return_format="for_block_output",
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "video_url", stored_url

View File

@@ -218,7 +218,11 @@ To find IDs, identify the values for organization_id when you call this endpoint
) -> BlockOutput:
query = SearchOrganizationsRequest(**input_data.model_dump())
organizations = await self.search_organizations(query, credentials)
self.merge_stats(NodeExecutionStats(output_size=len(organizations)))
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(organizations)), provider_cost_type="items"
)
)
for organization in organizations:
yield "organization", organization
yield "organizations", organizations

View File

@@ -366,5 +366,9 @@ class SearchPeopleBlock(Block):
*(enrich_or_fallback(person) for person in people)
)
self.merge_stats(NodeExecutionStats(output_size=len(people)))
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(people)), provider_cost_type="items"
)
)
yield "people", people

View File

@@ -13,7 +13,7 @@ from backend.blocks.apollo._auth import (
ApolloCredentialsInput,
)
from backend.blocks.apollo.models import Contact, EnrichPersonRequest
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
from backend.data.model import CredentialsField, SchemaField
class GetPersonDetailBlock(Block):
@@ -141,5 +141,4 @@ class GetPersonDetailBlock(Block):
**kwargs,
) -> BlockOutput:
query = EnrichPersonRequest(**input_data.model_dump())
self.merge_stats(NodeExecutionStats(output_size=1))
yield "contact", await self.enrich_person(query, credentials)

View File

@@ -17,7 +17,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -343,7 +342,6 @@ class ExecuteCodeBlock(Block, BaseE2BExecutorMixin):
# Determine result object shape & filter out empty formats
main_result, results = self.process_execution_results(results)
self.merge_stats(NodeExecutionStats(output_size=1))
if main_result:
yield "main_result", main_result
yield "results", results
@@ -469,7 +467,6 @@ class InstantiateCodeSandboxBlock(Block, BaseE2BExecutorMixin):
setup_commands=input_data.setup_commands,
timeout=input_data.timeout,
)
self.merge_stats(NodeExecutionStats(output_size=1))
if sandbox_id:
yield "sandbox_id", sandbox_id
else:
@@ -580,7 +577,6 @@ class ExecuteCodeStepBlock(Block, BaseE2BExecutorMixin):
# Determine result object shape & filter out empty formats
main_result, results = self.process_execution_results(results)
self.merge_stats(NodeExecutionStats(output_size=1))
if main_result:
yield "main_result", main_result
yield "results", results

View File

@@ -15,12 +15,7 @@ from backend.blocks._base import (
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
NodeExecutionStats,
SchemaField,
)
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.type import MediaFileType
from ._api import (
@@ -200,7 +195,6 @@ class GetLinkedinProfileBlock(Block):
include_social_media=input_data.include_social_media,
include_extra=input_data.include_extra,
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "profile", profile
except Exception as e:
logger.error(f"Error fetching LinkedIn profile: {str(e)}")
@@ -347,7 +341,6 @@ class LinkedinPersonLookupBlock(Block):
include_similarity_checks=input_data.include_similarity_checks,
enrich_profile=input_data.enrich_profile,
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "lookup_result", lookup_result
except Exception as e:
logger.error(f"Error looking up LinkedIn profile: {str(e)}")
@@ -450,7 +443,6 @@ class LinkedinRoleLookupBlock(Block):
company_name=input_data.company_name,
enrich_profile=input_data.enrich_profile,
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "role_lookup_result", role_lookup_result
except Exception as e:
logger.error(f"Error looking up role in company: {str(e)}")
@@ -531,7 +523,6 @@ class GetLinkedinProfilePictureBlock(Block):
credentials=credentials,
linkedin_profile_url=input_data.linkedin_profile_url,
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "profile_picture_url", profile_picture
except Exception as e:
logger.error(f"Error getting profile picture: {str(e)}")

View File

@@ -18,7 +18,7 @@ from backend.blocks.fal._auth import (
FalCredentialsInput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import NodeExecutionStats, SchemaField
from backend.data.model import SchemaField
from backend.util.file import store_media_file
from backend.util.request import ClientResponseError, Requests
from backend.util.type import MediaFileType
@@ -230,7 +230,6 @@ class AIVideoGeneratorBlock(Block):
execution_context=execution_context,
return_format="for_block_output",
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "video_url", stored_url
except Exception as e:
error_message = str(e)

View File

@@ -118,7 +118,11 @@ class GoogleMapsSearchBlock(Block):
input_data.radius,
input_data.max_results,
)
self.merge_stats(NodeExecutionStats(output_size=len(places)))
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(places)), provider_cost_type="items"
)
)
for place in places:
yield "place", place

View File

@@ -14,7 +14,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -228,7 +227,6 @@ class IdeogramModelBlock(Block):
image_url=result,
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "result", result
async def run_model(

View File

@@ -8,7 +8,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -154,7 +153,6 @@ class AddMemoryBlock(Block, Mem0Base):
messages,
**params,
)
self.merge_stats(NodeExecutionStats(output_size=1))
results = result.get("results", [])
yield "results", results
@@ -257,7 +255,6 @@ class SearchMemoryBlock(Block, Mem0Base):
result: list[dict[str, Any]] = client.search(
input_data.query, version="v2", filters=filters
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "memories", result
except Exception as e:
@@ -343,7 +340,6 @@ class GetAllMemoriesBlock(Block, Mem0Base):
filters=filters,
version="v2",
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "memories", memories
@@ -438,7 +434,6 @@ class GetLatestMemoryBlock(Block, Mem0Base):
filters=filters,
version="v2",
)
self.merge_stats(NodeExecutionStats(output_size=1))
if memories:
# Return the latest memory (first in the list as they're sorted by recency)

View File

@@ -10,7 +10,7 @@ from backend.blocks.nvidia._auth import (
NvidiaCredentialsField,
NvidiaCredentialsInput,
)
from backend.data.model import NodeExecutionStats, SchemaField
from backend.data.model import SchemaField
from backend.util.request import Requests
from backend.util.type import MediaFileType
@@ -69,7 +69,6 @@ class NvidiaDeepfakeDetectBlock(Block):
data = response.json()
result = data.get("data", [{}])[0]
self.merge_stats(NodeExecutionStats(output_size=1))
# Get deepfake probability from first bounding box if any
deepfake_prob = 0.0

View File

@@ -17,12 +17,7 @@ from backend.blocks.replicate._auth import (
ReplicateCredentialsInput,
)
from backend.blocks.replicate._helper import ReplicateOutputs, extract_result
from backend.data.model import (
APIKeyCredentials,
CredentialsField,
NodeExecutionStats,
SchemaField,
)
from backend.data.model import APIKeyCredentials, CredentialsField, SchemaField
from backend.util.exceptions import BlockExecutionError, BlockInputError
logger = logging.getLogger(__name__)
@@ -113,7 +108,6 @@ class ReplicateModelBlock(Block):
result = await self.run_model(
model_ref, input_data.model_inputs, credentials.api_key
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "result", result
yield "status", "succeeded"
yield "model_name", input_data.model_name

View File

@@ -16,7 +16,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -186,7 +185,6 @@ class ScreenshotWebPageBlock(Block):
block_chats=input_data.block_chats,
cache=input_data.cache,
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "image", screenshot_data["image"]
except Exception as e:
yield "error", str(e)

View File

@@ -15,7 +15,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -147,7 +146,6 @@ class GetWeatherInformationBlock(Block, GetRequest):
weather_data = await self.get_request(url, json=True)
if "main" in weather_data and "weather" in weather_data:
self.merge_stats(NodeExecutionStats(output_size=1))
yield "temperature", str(weather_data["main"]["temp"])
yield "humidity", str(weather_data["main"]["humidity"])
yield "condition", weather_data["weather"][0]["description"]

View File

@@ -100,7 +100,6 @@ class CreateCampaignBlock(Block):
**kwargs,
) -> BlockOutput:
response = await self.create_campaign(input_data.name, credentials)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "id", response.id
yield "name", response.name
@@ -227,7 +226,12 @@ class AddLeadToCampaignBlock(Block):
response = await self.add_leads_to_campaign(
input_data.campaign_id, input_data.lead_list, credentials
)
self.merge_stats(NodeExecutionStats(output_size=len(input_data.lead_list)))
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.lead_list)),
provider_cost_type="items",
)
)
yield "campaign_id", input_data.campaign_id
yield "upload_count", response.upload_count
@@ -323,7 +327,6 @@ class SaveCampaignSequencesBlock(Block):
response = await self.save_campaign_sequences(
input_data.campaign_id, input_data.sequences, credentials
)
self.merge_stats(NodeExecutionStats(output_size=1))
if response.data:
yield "data", response.data

View File

@@ -15,7 +15,6 @@ from backend.data.model import (
APIKeyCredentials,
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
)
from backend.integrations.providers import ProviderName
@@ -182,7 +181,6 @@ class CreateTalkingAvatarVideoBlock(Block):
execution_context=execution_context,
return_format="for_block_output",
)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "video_url", stored_url
return
elif status_response["status"] == "error":

View File

@@ -105,5 +105,10 @@ class UnrealTextToSpeechBlock(Block):
input_data.text,
input_data.voice_id,
)
self.merge_stats(NodeExecutionStats(output_size=len(input_data.text)))
self.merge_stats(
NodeExecutionStats(
provider_cost=float(len(input_data.text)),
provider_cost_type="characters",
)
)
yield "mp3_url", api_response["OutputUri"]

View File

@@ -19,7 +19,6 @@ from backend.blocks._base import (
from backend.data.model import (
CredentialsField,
CredentialsMetaInput,
NodeExecutionStats,
SchemaField,
UserPasswordCredentials,
)
@@ -171,7 +170,6 @@ class TranscribeYoutubeVideoBlock(Block):
transcript = self.get_transcript(video_id, credentials)
transcript_text = self.format_transcript(transcript=transcript)
self.merge_stats(NodeExecutionStats(output_size=1))
# Only yield after all operations succeed
yield "video_id", video_id
yield "transcript", transcript_text

View File

@@ -21,7 +21,7 @@ from backend.blocks.zerobounce._auth import (
ZeroBounceCredentials,
ZeroBounceCredentialsInput,
)
from backend.data.model import CredentialsField, NodeExecutionStats, SchemaField
from backend.data.model import CredentialsField, SchemaField
class Response(BaseModel):
@@ -177,6 +177,5 @@ class ValidateEmailsBlock(Block):
)
response_model = Response(**response.__dict__)
self.merge_stats(NodeExecutionStats(output_size=1))
yield "response", response_model

View File

@@ -819,6 +819,17 @@ class RefundRequest(BaseModel):
updated_at: datetime
ProviderCostType = Literal[
"cost_usd", # Actual USD cost reported by the provider
"tokens", # LLM token counts (sum of input + output)
"characters", # Per-character billing (TTS providers)
"sandbox_seconds", # Per-second compute billing (e.g. E2B)
"walltime_seconds", # Per-second billing incl. queue/polling
"per_run", # Per-API-call billing with fixed cost
"items", # Per-item billing (lead/organization/result count)
]
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
@@ -839,6 +850,10 @@ class NodeExecutionStats(BaseModel):
extra_cost: int = 0
extra_steps: int = 0
provider_cost: float | None = None
# Type of the provider-reported cost/usage captured above. When set
# by a block, resolve_tracking honors this directly instead of
# guessing from provider name.
provider_cost_type: Optional[ProviderCostType] = None
# Moderation fields
cleared_inputs: Optional[dict[str, list[str]]] = None
cleared_outputs: Optional[dict[str, list[str]]] = None

View File

@@ -32,18 +32,29 @@ def resolve_tracking(
stats: NodeExecutionStats,
input_data: dict[str, Any],
) -> tuple[str, float]:
"""Return (tracking_type, tracking_amount) based on provider billing model."""
# 1. Provider returned actual USD cost (OpenRouter, Exa)
"""Return (tracking_type, tracking_amount) based on provider billing model.
Preference order:
1. Block-declared: if the block set `provider_cost_type` on its stats,
honor it directly (paired with `provider_cost` as the amount).
2. Heuristic fallback: infer from `provider_cost`/token counts, then
from provider name for per-character / per-second billing.
"""
# 1. Block explicitly declared its cost type
if stats.provider_cost_type:
return stats.provider_cost_type, stats.provider_cost or 0.0
# 2. Provider returned actual USD cost (OpenRouter, Exa)
if stats.provider_cost is not None:
return "cost_usd", stats.provider_cost
# 2. LLM providers: track by tokens
# 3. LLM providers: track by tokens
if stats.input_token_count or stats.output_token_count:
return "tokens", float(
(stats.input_token_count or 0) + (stats.output_token_count or 0)
)
# 3. Provider-specific billing models
# 4. Provider-specific billing heuristics
# TTS: billed per character of input text
if provider == "unreal_speech":
@@ -69,7 +80,6 @@ def resolve_tracking(
# Per-request: Google Maps, Ideogram, Nvidia, Apollo, etc.
# All billed per API call - count 1 per block execution.
# output_size captured separately for volume estimation.
return "per_run", 1.0

View File

@@ -80,6 +80,31 @@ class TestResolveTracking:
assert tt == "characters"
assert amt == 9.0
def test_block_declared_cost_type_items(self):
"""Block explicitly setting provider_cost_type='items' short-circuits heuristics."""
stats = self._stats(provider_cost=5.0, provider_cost_type="items")
tt, amt = resolve_tracking("google_maps", stats, {})
assert tt == "items"
assert amt == 5.0
def test_block_declared_cost_type_characters(self):
"""TTS block can declare characters directly, bypassing input_data lookup."""
stats = self._stats(provider_cost=42.0, provider_cost_type="characters")
tt, amt = resolve_tracking("unreal_speech", stats, {})
assert tt == "characters"
assert amt == 42.0
def test_block_declared_cost_type_wins_over_tokens(self):
"""provider_cost_type takes precedence over token-based heuristic."""
stats = self._stats(
provider_cost=1.0,
provider_cost_type="per_run",
input_token_count=500,
)
tt, amt = resolve_tracking("openai", stats, {})
assert tt == "per_run"
assert amt == 1.0
def test_e2b_returns_sandbox_seconds(self):
stats = self._stats(walltime=45.123)
tt, amt = resolve_tracking("e2b", stats, {})