Merge branch 'fix/orchestrator-per-iteration-cost' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs

This commit is contained in:
majdyz
2026-04-13 08:01:50 +00:00
7 changed files with 1393 additions and 28 deletions

View File

@@ -420,6 +420,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
_optimized_description: ClassVar[str | None] = None
def extra_credit_charges(self, execution_stats: "NodeExecutionStats") -> int:
    """Number of additional base-cost credits to bill for this block run.

    The executor invokes this hook after a block finishes with COMPLETED
    status. ``_charge_usage`` already collected a single credit when
    execution started; whatever this method returns is charged on top of
    that. Blocks that perform several LLM calls in one run (e.g. the
    OrchestratorBlock) override this so each call is billed. The base
    implementation charges nothing extra.
    """
    return 0
def __init__(
self,
id: str = "",

View File

@@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext
from backend.data.model import NodeExecutionStats, SchemaField
from backend.util import json
from backend.util.clients import get_database_manager_async_client
from backend.util.exceptions import InsufficientBalanceError
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
from backend.util.security import SENSITIVE_FIELD_NAMES
from backend.util.tool_call_loop import (
@@ -364,10 +365,25 @@ def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None:
class OrchestratorBlock(Block):
"""A block that uses a language model to orchestrate tool calls.
Supports both single-shot and iterative agent mode execution.
**InsufficientBalanceError propagation contract**: ``InsufficientBalanceError``
(IBE) must always re-raise through every ``except`` block in this class.
Swallowing IBE would let the agent loop continue with unpaid work. Every
exception handler that catches ``Exception`` includes an explicit IBE
re-raise carve-out for this reason.
"""
A block that uses a language model to orchestrate tool calls, supporting both
single-shot and iterative agent mode execution.
"""
def extra_credit_charges(self, execution_stats: NodeExecutionStats) -> int:
    """Bill one additional base credit for every LLM call after the first.

    Each agent-mode iteration performs exactly one LLM call. The first
    call is already paid for by _charge_usage(); the executor uses this
    count to collect the remaining credits once the block completes.
    """
    extra_calls = execution_stats.llm_call_count - 1
    return extra_calls if extra_calls > 0 else 0
# MCP server name used by the Claude Code SDK execution mode. Keep in sync
# with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode.
@@ -1077,7 +1093,10 @@ class OrchestratorBlock(Block):
input_data=input_value,
)
assert node_exec_result is not None, "node_exec_result should not be None"
if node_exec_result is None:
raise RuntimeError(
f"upsert_execution_input returned None for node {sink_node_id}"
)
# Create NodeExecutionEntry for execution manager
node_exec_entry = NodeExecutionEntry(
@@ -1112,15 +1131,79 @@ class OrchestratorBlock(Block):
task=node_exec_future,
)
# Execute the node directly since we're in the Orchestrator context
node_exec_future.set_result(
await execution_processor.on_node_execution(
# Execute the node directly since we're in the Orchestrator context.
# Wrap in try/except so the future is always resolved, even on
# error — an unresolved Future would block anything awaiting it.
#
# on_node_execution is decorated with @async_error_logged(swallow=True),
# which catches BaseException and returns None rather than raising.
# Treat a None return as a failure: set_exception so the future
# carries an error state rather than a None result, and return an
# error response so the LLM knows the tool failed.
try:
tool_node_stats = await execution_processor.on_node_execution(
node_exec=node_exec_entry,
node_exec_progress=node_exec_progress,
nodes_input_masks=None,
graph_stats_pair=graph_stats_pair,
)
)
if tool_node_stats is None:
nil_err = RuntimeError(
f"on_node_execution returned None for node {sink_node_id} "
"(error was swallowed by @async_error_logged)"
)
node_exec_future.set_exception(nil_err)
resp = _create_tool_response(
tool_call.id,
"Tool execution returned no result",
responses_api=responses_api,
)
resp["_is_error"] = True
return resp
node_exec_future.set_result(tool_node_stats)
except Exception as exec_err:
node_exec_future.set_exception(exec_err)
raise
# Charge user credits AFTER successful tool execution. Tools
# spawned by the orchestrator bypass the main execution queue
# (where _charge_usage is called), so we must charge here to
# avoid free tool execution. Charging post-completion (vs.
# pre-execution) avoids billing users for failed tool calls.
# Skipped for dry runs.
#
# `error is None` intentionally excludes both Exception and
# BaseException subclasses (e.g. CancelledError) so cancelled
# or terminated tool runs are not billed.
#
# Billing errors (including non-balance exceptions) are kept
# in a separate try/except so they are never silently swallowed
# by the generic tool-error handler below.
if (
not execution_params.execution_context.dry_run
and tool_node_stats.error is None
):
try:
tool_cost, _ = await execution_processor.charge_node_usage(
node_exec_entry,
)
except InsufficientBalanceError:
# IBE must propagate — see OrchestratorBlock class docstring.
raise
except Exception:
# Non-billing charge failures (DB outage, network, etc.)
# must NOT propagate to the outer except handler because
# the tool itself succeeded. Re-raising would mark the
# tool as failed (_is_error=True), causing the LLM to
# retry side-effectful operations. Log and continue.
logger.exception(
"Unexpected error charging for tool node %s; "
"tool execution was successful",
sink_node_id,
)
tool_cost = 0
if tool_cost > 0:
self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))
# Get outputs from database after execution completes using database manager client
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
@@ -1133,18 +1216,26 @@ class OrchestratorBlock(Block):
if node_outputs
else "Tool executed successfully"
)
return _create_tool_response(
resp = _create_tool_response(
tool_call.id, tool_response_content, responses_api=responses_api
)
resp["_is_error"] = False
return resp
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.warning("Tool execution with manager failed: %s", e)
# Return error response
return _create_tool_response(
logger.warning("Tool execution with manager failed: %s", e, exc_info=True)
# Return a generic error to the LLM — internal exception messages
# may contain server paths, DB details, or infrastructure info.
resp = _create_tool_response(
tool_call.id,
f"Tool execution failed: {e}",
"Tool execution failed due to an internal error",
responses_api=responses_api,
)
resp["_is_error"] = True
return resp
async def _agent_mode_llm_caller(
self,
@@ -1244,13 +1335,16 @@ class OrchestratorBlock(Block):
content = str(raw_content)
else:
content = "Tool executed successfully"
tool_failed = content.startswith("Tool execution failed:")
tool_failed = result.get("_is_error", True)
return ToolCallResult(
tool_call_id=tool_call.id,
tool_name=tool_call.name,
content=content,
is_error=tool_failed,
)
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.error("Tool execution failed: %s", e)
return ToolCallResult(
@@ -1370,9 +1464,13 @@ class OrchestratorBlock(Block):
"arguments": tc.arguments,
},
)
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
# Catch all errors (validation, network, API) so that the block
# surfaces them as user-visible output instead of crashing.
# Catch all OTHER errors (validation, network, API) so that
# the block surfaces them as user-visible output instead of
# crashing.
yield "error", str(e)
return
@@ -1450,11 +1548,14 @@ class OrchestratorBlock(Block):
text = content
else:
text = json.dumps(content)
tool_failed = text.startswith("Tool execution failed:")
tool_failed = result.get("_is_error", True)
return {
"content": [{"type": "text", "text": text}],
"isError": tool_failed,
}
except InsufficientBalanceError:
# IBE must propagate — see class docstring.
raise
except Exception as e:
logger.error("SDK tool execution failed: %s", e)
return {
@@ -1733,11 +1834,15 @@ class OrchestratorBlock(Block):
await pending_task
except (asyncio.CancelledError, StopAsyncIteration):
pass
except InsufficientBalanceError:
# IBE must propagate — see class docstring. The `finally`
# block below still runs and records partial token usage.
raise
except Exception as e:
# Surface SDK errors as user-visible output instead of crashing,
# consistent with _execute_tools_agent_mode error handling.
# Don't return yet — fall through to merge_stats below so
# partial token usage is always recorded.
# Surface OTHER SDK errors as user-visible output instead
# of crashing, consistent with _execute_tools_agent_mode
# error handling. Don't return yet — fall through to
# merge_stats below so partial token usage is always recorded.
sdk_error = e
finally:
# Always record usage stats, even on error. The SDK may have

View File

@@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode():
mock_execution_processor.on_node_execution = AsyncMock(
return_value=mock_node_stats
)
# Mock charge_node_usage (called after successful tool execution).
# Returns (cost, remaining_balance). Must be AsyncMock because it is
# an async method and is directly awaited in _execute_single_tool_with_manager.
# Use a non-zero cost so the merge_stats branch is exercised.
mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990))
# Mock the get_execution_outputs_by_node_exec_id method
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
@@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode():
# Verify tool was executed via execution processor
assert mock_execution_processor.on_node_execution.call_count == 1
# Verify charge_node_usage was actually called for the successful
# tool execution — this guards against regressions where the
# post-execution tool charging is accidentally removed.
assert mock_execution_processor.charge_node_usage.call_count == 1
@pytest.mark.asyncio
async def test_orchestrator_traditional_mode_default():

View File

@@ -641,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation():
mock_execution_processor.on_node_execution.return_value = (
mock_node_stats
)
# Mock charge_node_usage (called after successful tool execution).
# Must be AsyncMock because it is async and is awaited in
# _execute_single_tool_with_manager — a plain MagicMock would
# return a non-awaitable tuple and TypeError out, then be
# silently swallowed by the orchestrator's catch-all.
mock_execution_processor.charge_node_usage = AsyncMock(
return_value=(0, 0)
)
async for output_name, output_value in block.run(
input_data,

View File

@@ -0,0 +1,973 @@
"""Tests for OrchestratorBlock per-iteration cost charging.
The OrchestratorBlock in agent mode makes multiple LLM calls in a single
node execution. The executor uses ``Block.extra_credit_charges`` to detect
this and charge ``base_cost * (llm_call_count - 1)`` extra credits after
the block completes.
"""
import threading
from unittest.mock import AsyncMock, MagicMock
import pytest
from backend.blocks._base import Block
from backend.blocks.orchestrator import ExecutionParams, OrchestratorBlock
from backend.data.execution import ExecutionContext, ExecutionStatus
from backend.data.model import NodeExecutionStats
from backend.executor import manager
from backend.util.exceptions import InsufficientBalanceError
# ── extra_credit_charges hook ────────────────────────────────────────
class _NoOpBlock(Block):
    """Bare-bones Block subclass that keeps the inherited extra_credit_charges.

    Used to prove the base-class default (0 extra credits) is what a block
    gets when it does not opt into per-call billing.
    """

    def __init__(self):
        super().__init__(id="noop-block", description="No-op test block")

    def run(self, input_data, **kwargs):  # type: ignore[override]
        yield "out", {}
class TestExtraCreditCharges:
    """OrchestratorBlock opts into per-LLM-call billing via extra_credit_charges."""

    def test_orchestrator_returns_nonzero_for_multiple_calls(self):
        # 3 LLM calls → first is covered up front, 2 extra credits due.
        stats = NodeExecutionStats(llm_call_count=3)
        assert OrchestratorBlock().extra_credit_charges(stats) == 2

    def test_orchestrator_returns_zero_for_single_call(self):
        stats = NodeExecutionStats(llm_call_count=1)
        assert OrchestratorBlock().extra_credit_charges(stats) == 0

    def test_orchestrator_returns_zero_for_zero_calls(self):
        # Must not go negative when no LLM call happened at all.
        stats = NodeExecutionStats(llm_call_count=0)
        assert OrchestratorBlock().extra_credit_charges(stats) == 0

    def test_default_block_returns_zero(self):
        """A block that does not override extra_credit_charges returns 0."""
        stats = NodeExecutionStats(llm_call_count=10)
        assert _NoOpBlock().extra_credit_charges(stats) == 0
# ── charge_extra_iterations math ───────────────────────────────────
@pytest.fixture()
def fake_node_exec():
    """NodeExecutionEntry-shaped mock shared by the charging tests."""
    # Mock(**kwargs) configures the listed attributes directly.
    return MagicMock(
        user_id="u",
        graph_exec_id="g",
        graph_id="g",
        node_exec_id="ne",
        node_id="n",
        block_id="b",
        inputs={},
    )
@pytest.fixture()
def patched_processor(monkeypatch):
    """ExecutionProcessor with stubbed db client / block lookup helpers.

    Returns the processor and a list of credit amounts spent so tests can
    assert on what was charged.

    Note: ``ExecutionProcessor.__new__()`` bypasses ``__init__`` — if
    ``__init__`` gains required state in the future this fixture will need
    updating.
    """
    # Every successful spend_credits call appends its cost here.
    spent: list[int] = []

    class FakeDb:
        # Signature mirrors the real client's keyword-only interface.
        def spend_credits(self, *, user_id, cost, metadata):
            spent.append(cost)
            return 1000  # remaining balance

    fake_block = MagicMock()
    fake_block.name = "FakeBlock"
    monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
    monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
    # Fixed base cost of 10 credits per iteration for predictable math.
    monkeypatch.setattr(
        manager,
        "block_usage_cost",
        lambda block, input_data, **_kw: (10, {"model": "claude-sonnet-4-6"}),
    )
    proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
    return proc, spent
class TestChargeExtraIterations:
    """charge_extra_iterations math: base_cost × extra_iterations, with guards."""

    @pytest.mark.asyncio
    async def test_zero_extra_iterations_charges_nothing(
        self, patched_processor, fake_node_exec
    ):
        proc, spent = patched_processor
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=0
        )
        assert cost == 0
        assert balance == 0
        assert spent == []

    @pytest.mark.asyncio
    async def test_extra_iterations_multiplies_base_cost(
        self, patched_processor, fake_node_exec
    ):
        proc, spent = patched_processor
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=4
        )
        assert cost == 40  # 4 × 10
        assert balance == 1000
        # Charged as a single spend, not one per iteration.
        assert spent == [40]

    @pytest.mark.asyncio
    async def test_negative_extra_iterations_charges_nothing(
        self, patched_processor, fake_node_exec
    ):
        # Defensive guard: a negative count must never produce a refund.
        proc, spent = patched_processor
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=-1
        )
        assert cost == 0
        assert balance == 0
        assert spent == []

    @pytest.mark.asyncio
    async def test_capped_at_max(self, monkeypatch, fake_node_exec):
        """Runaway llm_call_count is capped at _MAX_EXTRA_ITERATIONS."""
        spent: list[int] = []

        class FakeDb:
            def spend_credits(self, *, user_id, cost, metadata):
                spent.append(cost)
                return 1000

        fake_block = MagicMock()
        fake_block.name = "FakeBlock"
        monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
        monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
        monkeypatch.setattr(
            manager,
            "block_usage_cost",
            lambda block, input_data, **_kw: (10, {}),
        )
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        cap = manager.ExecutionProcessor._MAX_EXTRA_ITERATIONS
        # Request far more than the cap allows.
        cost, _ = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=cap * 100
        )
        # Charged at most cap × 10
        assert cost == cap * 10
        assert spent == [cap * 10]

    @pytest.mark.asyncio
    async def test_zero_base_cost_skips_charge(self, monkeypatch, fake_node_exec):
        # Free blocks (base cost 0) must never hit spend_credits.
        spent: list[int] = []

        class FakeDb:
            def spend_credits(self, *, user_id, cost, metadata):
                spent.append(cost)
                return 0

        fake_block = MagicMock()
        fake_block.name = "FakeBlock"
        monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
        monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
        monkeypatch.setattr(
            manager, "block_usage_cost", lambda block, input_data, **_kw: (0, {})
        )
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=4
        )
        assert cost == 0
        assert balance == 0
        assert spent == []

    @pytest.mark.asyncio
    async def test_block_not_found_skips_charge(self, monkeypatch, fake_node_exec):
        # get_block returning None (unknown block_id) must short-circuit.
        spent: list[int] = []

        class FakeDb:
            def spend_credits(self, *, user_id, cost, metadata):
                spent.append(cost)
                return 0

        monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
        monkeypatch.setattr(manager, "get_block", lambda block_id: None)
        monkeypatch.setattr(
            manager, "block_usage_cost", lambda block, input_data, **_kw: (10, {})
        )
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=3
        )
        assert cost == 0
        assert balance == 0
        assert spent == []

    @pytest.mark.asyncio
    async def test_propagates_insufficient_balance_error(
        self, monkeypatch, fake_node_exec
    ):
        """Out-of-credits errors must propagate, not be silently swallowed."""

        class FakeDb:
            def spend_credits(self, *, user_id, cost, metadata):
                raise InsufficientBalanceError(
                    user_id=user_id,
                    message="Insufficient balance",
                    balance=0,
                    amount=cost,
                )

        fake_block = MagicMock()
        fake_block.name = "FakeBlock"
        monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
        monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
        monkeypatch.setattr(
            manager, "block_usage_cost", lambda block, input_data, **_kw: (10, {})
        )
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        with pytest.raises(InsufficientBalanceError):
            await proc.charge_extra_iterations(fake_node_exec, extra_iterations=4)
# ── charge_node_usage ──────────────────────────────────────────────
class TestChargeNodeUsage:
    """charge_node_usage delegates to _charge_usage with execution_count=0."""

    @pytest.mark.asyncio
    async def test_delegates_with_zero_execution_count(
        self, monkeypatch, fake_node_exec
    ):
        """Nested tool charges should NOT inflate the per-execution counter."""
        # Captures the arguments _charge_usage was actually called with.
        captured: dict = {}

        def fake_charge_usage(self, node_exec, execution_count):
            captured["execution_count"] = execution_count
            captured["node_exec"] = node_exec
            return (5, 100)

        def fake_handle_low_balance(
            self, db_client, user_id, current_balance, transaction_cost
        ):
            pass

        monkeypatch.setattr(
            manager.ExecutionProcessor, "_charge_usage", fake_charge_usage
        )
        monkeypatch.setattr(
            manager.ExecutionProcessor, "_handle_low_balance", fake_handle_low_balance
        )
        monkeypatch.setattr(manager, "get_db_client", lambda: MagicMock())
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        cost, balance = await proc.charge_node_usage(fake_node_exec)
        assert cost == 5
        assert balance == 100
        assert captured["execution_count"] == 0

    @pytest.mark.asyncio
    async def test_calls_handle_low_balance_when_cost_nonzero(
        self, monkeypatch, fake_node_exec
    ):
        """charge_node_usage should call _handle_low_balance when total_cost > 0."""
        low_balance_calls: list[dict] = []

        def fake_charge_usage(self, node_exec, execution_count):
            return (10, 50)

        def fake_handle_low_balance(
            self, db_client, user_id, current_balance, transaction_cost
        ):
            low_balance_calls.append(
                {
                    "user_id": user_id,
                    "current_balance": current_balance,
                    "transaction_cost": transaction_cost,
                }
            )

        monkeypatch.setattr(
            manager.ExecutionProcessor, "_charge_usage", fake_charge_usage
        )
        monkeypatch.setattr(
            manager.ExecutionProcessor, "_handle_low_balance", fake_handle_low_balance
        )
        monkeypatch.setattr(manager, "get_db_client", lambda: MagicMock())
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        cost, balance = await proc.charge_node_usage(fake_node_exec)
        assert cost == 10
        assert balance == 50
        # Exactly one alert, carrying the post-charge balance and cost.
        assert len(low_balance_calls) == 1
        assert low_balance_calls[0]["user_id"] == "u"
        assert low_balance_calls[0]["current_balance"] == 50
        assert low_balance_calls[0]["transaction_cost"] == 10

    @pytest.mark.asyncio
    async def test_skips_handle_low_balance_when_cost_zero(
        self, monkeypatch, fake_node_exec
    ):
        """charge_node_usage should NOT call _handle_low_balance when cost is 0."""
        low_balance_calls: list = []

        def fake_charge_usage(self, node_exec, execution_count):
            return (0, 200)

        def fake_handle_low_balance(
            self, db_client, user_id, current_balance, transaction_cost
        ):
            low_balance_calls.append(True)

        monkeypatch.setattr(
            manager.ExecutionProcessor, "_charge_usage", fake_charge_usage
        )
        monkeypatch.setattr(
            manager.ExecutionProcessor, "_handle_low_balance", fake_handle_low_balance
        )
        monkeypatch.setattr(manager, "get_db_client", lambda: MagicMock())
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        cost, balance = await proc.charge_node_usage(fake_node_exec)
        assert cost == 0
        assert low_balance_calls == []
# ── on_node_execution charging gate ────────────────────────────────
class _FakeNode:
"""Minimal stand-in for a ``Node`` object with a block attribute."""
def __init__(self, extra_charges: int = 0, block_name: str = "FakeBlock"):
self.block = MagicMock()
self.block.name = block_name
self.block.extra_credit_charges = MagicMock(return_value=extra_charges)
class _FakeExecContext:
def __init__(self, dry_run: bool = False):
self.dry_run = dry_run
def _make_node_exec(dry_run: bool = False) -> MagicMock:
    """Build a NodeExecutionEntry-like mock for on_node_execution tests."""
    # Mock(**kwargs) configures the plain attributes in one call; the
    # execution context is attached afterwards so the dry_run flag is real.
    entry = MagicMock(
        user_id="u",
        graph_id="g",
        graph_exec_id="ge",
        node_id="n",
        node_exec_id="ne",
        block_id="b",
        inputs={},
    )
    entry.execution_context = _FakeExecContext(dry_run=dry_run)
    return entry
@pytest.fixture()
def gated_processor(monkeypatch):
    """ExecutionProcessor with on_node_execution's downstream calls stubbed.

    Lets tests flip the gate conditions (status, extra_credit_charges result,
    llm_call_count, dry_run) and observe whether charge_extra_iterations
    was called.

    Returns ``(proc, calls, inner_result, fake_db, NodeExecutionStats)``.
    """
    calls: dict[str, list] = {
        "charge_extra_iterations": [],
        "handle_low_balance": [],
        "handle_insufficient_funds_notif": [],
    }
    # Stub node lookup + DB client so the wrapper doesn't touch real infra.
    fake_db = MagicMock()
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2))
    monkeypatch.setattr(manager, "get_db_async_client", lambda: fake_db)
    monkeypatch.setattr(manager, "get_db_client", lambda: fake_db)
    # get_block is called by LogMetadata construction in on_node_execution.
    # NOTE: ``MagicMock(name="FakeBlock")`` would NOT give the mock a ``.name``
    # attribute equal to "FakeBlock" — the ``name`` kwarg only names the mock
    # object itself (a documented unittest.mock gotcha). Assign the attribute
    # explicitly, matching the fake_block pattern used elsewhere in this file.
    fake_block = MagicMock()
    fake_block.name = "FakeBlock"
    monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
    # Persistence + cost logging are not under test here.
    monkeypatch.setattr(
        manager,
        "async_update_node_execution_status",
        AsyncMock(return_value=None),
    )
    monkeypatch.setattr(
        manager,
        "async_update_graph_execution_state",
        AsyncMock(return_value=None),
    )
    monkeypatch.setattr(
        manager,
        "log_system_credential_cost",
        AsyncMock(return_value=None),
    )
    proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
    # Control the status returned by the inner execution call.
    inner_result = {"status": ExecutionStatus.COMPLETED, "llm_call_count": 3}

    async def fake_inner(
        self,
        *,
        node,
        node_exec,
        node_exec_progress,
        stats,
        db_client,
        log_metadata,
        nodes_input_masks=None,
        nodes_to_skip=None,
    ):
        # Simulate the inner run: record the call count on the shared stats
        # object and report whatever status the test configured.
        stats.llm_call_count = inner_result["llm_call_count"]
        return MagicMock(wall_time=0.1, cpu_time=0.1), inner_result["status"]

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_on_node_execution",
        fake_inner,
    )

    async def fake_charge_extra(self, node_exec, extra_iterations):
        calls["charge_extra_iterations"].append(extra_iterations)
        # 10 credits per iteration, 500 remaining balance.
        return (extra_iterations * 10, 500)

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "charge_extra_iterations",
        fake_charge_extra,
    )

    def fake_low_balance(self, db_client, user_id, current_balance, transaction_cost):
        calls["handle_low_balance"].append(
            {
                "user_id": user_id,
                "current_balance": current_balance,
                "transaction_cost": transaction_cost,
            }
        )

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_handle_low_balance",
        fake_low_balance,
    )

    def fake_notif(self, db_client, user_id, graph_id, e):
        calls["handle_insufficient_funds_notif"].append(
            {"user_id": user_id, "graph_id": graph_id, "error": e}
        )

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_handle_insufficient_funds_notif",
        fake_notif,
    )
    return proc, calls, inner_result, fake_db, NodeExecutionStats
@pytest.mark.asyncio
async def test_on_node_execution_charges_extra_iterations_when_gate_passes(
    gated_processor,
):
    """COMPLETED + extra_credit_charges > 0 + not dry_run → charged."""
    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 3  # → extra_charges = 2
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2))
    # Shared graph-level stats accumulator plus its lock, as the wrapper expects.
    stats_pair = (
        MagicMock(
            node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0
        ),
        threading.Lock(),
    )
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == [2]
    # _handle_low_balance must be called with the remaining balance returned by
    # charge_extra_iterations (500) so users are alerted when balance drops low.
    assert len(calls["handle_low_balance"]) == 1
@pytest.mark.asyncio
async def test_on_node_execution_skips_when_status_not_completed(gated_processor):
    """FAILED runs must not be billed for extra iterations, even with calls > 1."""
    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.FAILED
    inner["llm_call_count"] = 5
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4))
    stats_pair = (
        MagicMock(
            node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0
        ),
        threading.Lock(),
    )
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    # Gate closed: no extra-iteration charge was attempted.
    assert calls["charge_extra_iterations"] == []
@pytest.mark.asyncio
async def test_on_node_execution_skips_when_extra_charges_zero(gated_processor):
    """Blocks reporting 0 extra charges are never billed post-completion."""
    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 5
    # Block returns 0 extra charges (base class default)
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0))
    stats_pair = (
        MagicMock(
            node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0
        ),
        threading.Lock(),
    )
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == []
@pytest.mark.asyncio
async def test_on_node_execution_skips_when_dry_run(gated_processor):
    """dry_run executions never charge, regardless of call count."""
    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 5
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4))
    stats_pair = (
        MagicMock(
            node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0
        ),
        threading.Lock(),
    )
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=True),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == []
@pytest.mark.asyncio
async def test_on_node_execution_insufficient_balance_records_error_and_notifies(
    monkeypatch,
    gated_processor,
):
    """When extra-iteration charging fails with InsufficientBalanceError:

    - the run still reports COMPLETED (the work is already done)
    - execution_stats.error is NOT set (would flip node_error_count and
      leak balance amounts into persisted node_stats — see manager.py
      comment in the IBE handler)
    - _handle_insufficient_funds_notif is called so the user is notified
    - the structured ERROR log is the alerting hook
    """
    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 4
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3))

    # Replace the fixture's charging stub with one that runs out of credits.
    async def raise_ibe(self, node_exec, extra_iterations):
        raise InsufficientBalanceError(
            user_id=node_exec.user_id,
            message="Insufficient balance",
            balance=0,
            amount=extra_iterations * 10,
        )

    monkeypatch.setattr(
        manager.ExecutionProcessor, "charge_extra_iterations", raise_ibe
    )
    stats_pair = (
        MagicMock(
            node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0
        ),
        threading.Lock(),
    )
    result_stats = await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    # error stays None — node ran to completion, only the post-hoc
    # charge failed. Setting .error would (a) flip node_error_count++
    # creating an "errored COMPLETED node" inconsistency, and (b) leak
    # balance amounts into persisted node_stats.
    assert result_stats.error is None
    # User notification fired.
    assert len(calls["handle_insufficient_funds_notif"]) == 1
    assert calls["handle_insufficient_funds_notif"][0]["user_id"] == "u"
# ── Orchestrator _execute_single_tool_with_manager charging gates ──
async def _run_tool_exec_with_stats(
    *,
    dry_run: bool,
    tool_stats_error,
    charge_node_usage_mock=None,
):
    """Invoke _execute_single_tool_with_manager against fully mocked deps.

    Returns ``(charge_mock, raised)`` — the AsyncMock standing in for
    ``charge_node_usage`` (so callers can assert on ``call_count``) and the
    exception the call raised, or ``None`` if it returned normally.

    Used to prove the dry_run and error guards around charge_node_usage
    behave as documented, and that InsufficientBalanceError propagates.
    """
    from collections import defaultdict
    from unittest.mock import AsyncMock, MagicMock, patch

    block = OrchestratorBlock()
    # Mocked async DB client used inside orchestrator.
    mock_db_client = AsyncMock()
    mock_target_node = MagicMock()
    mock_target_node.block_id = "test-block-id"
    mock_target_node.input_default = {}
    mock_db_client.get_node.return_value = mock_target_node
    mock_node_exec_result = MagicMock()
    mock_node_exec_result.node_exec_id = "test-tool-exec-id"
    mock_db_client.upsert_execution_input.return_value = (
        mock_node_exec_result,
        {"query": "t"},
    )
    mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"}
    # ExecutionProcessor mock: on_node_execution returns supplied error.
    mock_processor = AsyncMock()
    mock_processor.running_node_execution = defaultdict(MagicMock)
    mock_processor.execution_stats = MagicMock()
    mock_processor.execution_stats_lock = threading.Lock()
    mock_node_stats = MagicMock()
    mock_node_stats.error = tool_stats_error
    mock_processor.on_node_execution = AsyncMock(return_value=mock_node_stats)
    # Callers may inject a raising charge mock; default succeeds with cost 10.
    mock_processor.charge_node_usage = charge_node_usage_mock or AsyncMock(
        return_value=(10, 990)
    )
    # Build a tool_info shaped like _build_tool_info_from_args output.
    tool_call = MagicMock()
    tool_call.id = "call-1"
    tool_call.name = "search_keywords"
    tool_call.arguments = '{"query":"t"}'
    tool_def = {
        "type": "function",
        "function": {
            "name": "search_keywords",
            "_sink_node_id": "test-sink-node-id",
            "_field_mapping": {},
            "parameters": {
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
            },
        },
    }
    tool_info = OrchestratorBlock._build_tool_info_from_args(
        tool_call_id="call-1",
        tool_name="search_keywords",
        tool_args={"query": "t"},
        tool_def=tool_def,
    )
    exec_params = ExecutionParams(
        user_id="u",
        graph_id="g",
        node_id="n",
        graph_version=1,
        graph_exec_id="ge",
        node_exec_id="ne",
        execution_context=ExecutionContext(
            human_in_the_loop_safe_mode=False, dry_run=dry_run
        ),
    )
    with patch(
        "backend.blocks.orchestrator.get_database_manager_async_client",
        return_value=mock_db_client,
    ):
        try:
            await block._execute_single_tool_with_manager(
                tool_info, exec_params, mock_processor, responses_api=False
            )
            raised = None
        except Exception as e:
            raised = e
    return mock_processor.charge_node_usage, raised
@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_dry_run():
    """A dry run must never bill the user: charge_node_usage stays untouched."""
    charge_spy, propagated = await _run_tool_exec_with_stats(
        dry_run=True, tool_stats_error=None
    )
    assert propagated is None
    charge_spy.assert_not_called()
@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_failed_tool():
    """A tool run that errored must not be billed.

    When tool_node_stats.error carries an Exception, charge_node_usage is
    never invoked and nothing propagates out of the executor.
    """
    charge_spy, propagated = await _run_tool_exec_with_stats(
        dry_run=False, tool_stats_error=RuntimeError("tool blew up")
    )
    assert propagated is None
    charge_spy.assert_not_called()
@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_cancelled_tool():
    """A cancelled tool (BaseException, not Exception) must not be billed.

    Regression guard for sentry's BaseException concern: the old
    ``isinstance(error, Exception)`` check treated CancelledError as
    "no error" and would have charged the user for a terminated run.
    """
    import asyncio

    charge_spy, propagated = await _run_tool_exec_with_stats(
        dry_run=False, tool_stats_error=asyncio.CancelledError()
    )
    assert propagated is None
    charge_spy.assert_not_called()
@pytest.mark.asyncio
async def test_tool_execution_insufficient_balance_propagates():
    """InsufficientBalanceError raised while charging must escape the block.

    Were it folded into a ToolCallResult instead, the LLM loop would keep
    iterating on 'tool failed' errors and the user would get unpaid work.
    """
    charge_spy = AsyncMock(
        side_effect=InsufficientBalanceError(
            user_id="u", message="nope", balance=0, amount=10
        )
    )
    _, propagated = await _run_tool_exec_with_stats(
        dry_run=False,
        tool_stats_error=None,
        charge_node_usage_mock=charge_spy,
    )
    assert isinstance(propagated, InsufficientBalanceError)
# ── on_node_execution FAILED + InsufficientBalanceError notification ──
@pytest.mark.asyncio
async def test_on_node_execution_failed_ibe_sends_notification(
    monkeypatch,
    gated_processor,
):
    """FAILED status + InsufficientBalanceError in stats → user is notified.

    Covers the path where a nested tool charge inside the orchestrator raises
    InsufficientBalanceError: it escapes the block's run() generator, is
    caught by _on_node_execution's broad except (status=FAILED,
    execution_stats.error=IBE), and on_node_execution's post-execution block
    then sends the notification so the user understands why the run stopped.
    """
    proc, calls, inner, fake_db, NodeExecutionStats = gated_processor
    balance_error = InsufficientBalanceError(
        user_id="u",
        message="Insufficient balance",
        balance=0,
        amount=30,
    )

    async def stub_failed_inner(
        self,
        *,
        node,
        node_exec,
        node_exec_progress,
        stats,
        db_client,
        log_metadata,
        nodes_input_masks=None,
        nodes_to_skip=None,
    ):
        # Emulate _on_node_execution returning FAILED with the IBE attached.
        stats.error = balance_error
        return MagicMock(wall_time=0.1, cpu_time=0.1), ExecutionStatus.FAILED

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_on_node_execution",
        stub_failed_inner,
    )
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0))

    graph_stats = MagicMock(
        node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0
    )
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, threading.Lock()),
    )

    # Exactly one notification, addressed to the right user.
    notifications = calls["handle_insufficient_funds_notif"]
    assert len(notifications) == 1
    assert notifications[0]["user_id"] == "u"
    # FAILED runs must never be billed extra iterations.
    assert calls["charge_extra_iterations"] == []
# ── Billing leak: non-IBE exception during extra-iteration charging ──
@pytest.mark.asyncio
async def test_on_node_execution_non_ibe_billing_failure_keeps_completed(
    monkeypatch,
    gated_processor,
):
    """A non-IBE failure while charging extras must not taint the run.

    If charge_extra_iterations blows up with e.g. a DB outage:
    - execution_stats.error stays None (the node itself completed),
    - the status stays COMPLETED (the work was already done),
    - the billing_leak error is only logged, never written into stats,
    - and no insufficient-funds notification goes out (IBE-only behavior).
    """
    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 4
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3))

    async def explode_on_charge(self, node_exec, extra_iterations):
        raise ConnectionError("DB connection lost")

    monkeypatch.setattr(
        manager.ExecutionProcessor, "charge_extra_iterations", explode_on_charge
    )

    graph_stats = MagicMock(
        node_count=0,
        nodes_cputime=0,
        nodes_walltime=0,
        cost=0,
        node_error_count=0,
    )
    result_stats = await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, threading.Lock()),
    )

    # The node completed; only the post-hoc billing failed.
    assert result_stats.error is None
    # Notifications are reserved for InsufficientBalanceError.
    assert len(calls["handle_insufficient_funds_notif"]) == 0
# ── _charge_usage with execution_count=0 ──
class TestChargeUsageZeroExecutionCount:
    """_charge_usage(node_exec, 0) must never consult execution_usage_cost."""

    def test_execution_count_zero_skips_execution_tier(self, monkeypatch):
        """Only the block-usage tier is charged when execution_count is 0."""
        tier_calls: list[int] = []
        charges: list[int] = []

        def spy_execution_usage_cost(count):
            tier_calls.append(count)
            return (100, count)

        class StubDb:
            def spend_credits(self, *, user_id, cost, metadata):
                charges.append(cost)
                return 500

        stub_block = MagicMock()
        stub_block.name = "FakeBlock"

        monkeypatch.setattr(manager, "get_db_client", lambda: StubDb())
        monkeypatch.setattr(manager, "get_block", lambda block_id: stub_block)
        monkeypatch.setattr(
            manager,
            "block_usage_cost",
            lambda block, input_data, **_kw: (10, {}),
        )
        monkeypatch.setattr(
            manager, "execution_usage_cost", spy_execution_usage_cost
        )

        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        node_exec = MagicMock()
        node_exec.user_id = "u"
        node_exec.graph_exec_id = "ge"
        node_exec.graph_id = "g"
        node_exec.node_exec_id = "ne"
        node_exec.node_id = "n"
        node_exec.block_id = "b"
        node_exec.inputs = {}

        total_cost, remaining = proc._charge_usage(node_exec, 0)

        assert total_cost == 10  # block-usage cost only
        assert remaining == 500
        assert charges == [10]
        # The execution-count tier must never have been consulted.
        assert tier_calls == []

View File

@@ -956,6 +956,12 @@ async def test_agent_mode_conversation_valid_for_responses_api():
ep.execution_stats_lock = threading.Lock()
ns = MagicMock(error=None)
ep.on_node_execution = AsyncMock(return_value=ns)
# Mock charge_node_usage (called after successful tool execution).
# Must be AsyncMock because it is async and is awaited in
# _execute_single_tool_with_manager — a plain MagicMock would return a
# non-awaitable tuple and TypeError out, then be silently swallowed by
# the orchestrator's catch-all.
ep.charge_node_usage = AsyncMock(return_value=(0, 0))
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
block, "_create_tool_node_signatures", return_value=tool_sigs

View File

@@ -19,7 +19,7 @@ from sentry_sdk.api import flush as _sentry_flush
from sentry_sdk.api import get_current_scope as _sentry_get_current_scope
from backend.blocks import get_block
from backend.blocks._base import BlockSchema
from backend.blocks._base import Block, BlockSchema
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentOutputBlock
from backend.blocks.mcp.block import MCPToolBlock
@@ -681,6 +681,10 @@ class ExecutionProcessor:
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
await self._handle_post_execution_billing(
node, node_exec, execution_stats, status, log_metadata
)
graph_stats, graph_stats_lock = graph_stats_pair
with graph_stats_lock:
graph_stats.node_count += 1 + execution_stats.extra_steps
@@ -716,8 +720,121 @@ class ExecutionProcessor:
db_client=db_client,
)
# If the node failed because a nested tool charge raised IBE,
# send the user notification so they understand why the run stopped.
if status == ExecutionStatus.FAILED and isinstance(
execution_stats.error, InsufficientBalanceError
):
await self._try_send_insufficient_funds_notif(
node_exec.user_id,
node_exec.graph_id,
execution_stats.error,
log_metadata,
)
return execution_stats
async def _try_send_insufficient_funds_notif(
    self,
    user_id: str,
    graph_id: str,
    error: InsufficientBalanceError,
    log_metadata: LogMetadata,
) -> None:
    """Best-effort insufficient-funds notification; never raises.

    Runs the blocking notification handler in a worker thread and downgrades
    any delivery failure to a warning so billing paths cannot be derailed by
    notification errors.
    """
    try:
        notify = self._handle_insufficient_funds_notif
        await asyncio.to_thread(notify, get_db_client(), user_id, graph_id, error)
    except Exception as notif_error:  # pragma: no cover
        log_metadata.warning(
            f"Failed to send insufficient funds notification: {notif_error}"
        )
async def _handle_post_execution_billing(
    self,
    node: Node,
    node_exec: NodeExecutionEntry,
    execution_stats: NodeExecutionStats,
    status: ExecutionStatus,
    log_metadata: LogMetadata,
) -> None:
    """Charge extra iterations for blocks that opt into per-LLM-call billing.

    The first LLM call is already covered by ``_charge_usage()``; each
    additional call costs another ``base_cost``. Skipped for dry runs and
    failed runs.

    InsufficientBalanceError here is a post-hoc billing leak: the work is
    already done but the user can no longer pay. The run stays COMPLETED and
    the error is logged with ``billing_leak: True`` for alerting.
    """
    # Only completed, non-dry runs are eligible; the block itself decides
    # how many extra base-cost charges it accrued (0 for most blocks).
    extra_iterations = (
        node.block.extra_credit_charges(execution_stats)
        if status == ExecutionStatus.COMPLETED
        and not node_exec.execution_context.dry_run
        else 0
    )
    if extra_iterations <= 0:
        return
    try:
        extra_cost, remaining_balance = await self.charge_extra_iterations(
            node_exec,
            extra_iterations,
        )
        if extra_cost > 0:
            execution_stats.extra_cost += extra_cost
            # _handle_low_balance is blocking; keep it off the event loop.
            await asyncio.to_thread(
                self._handle_low_balance,
                get_db_client(),
                node_exec.user_id,
                remaining_balance,
                extra_cost,
            )
    # Order matters: InsufficientBalanceError must be handled before the
    # generic Exception clause so it gets the notification path below.
    except InsufficientBalanceError as e:
        log_metadata.error(
            "billing_leak: insufficient balance after "
            f"{node.block.name} completed {extra_iterations} "
            f"extra iterations",
            extra={
                "billing_leak": True,
                "user_id": node_exec.user_id,
                "graph_id": node_exec.graph_id,
                "block_id": node_exec.block_id,
                "extra_iterations": extra_iterations,
                "error": str(e),
            },
        )
        # Do NOT set execution_stats.error — the node ran to completion,
        # only the post-hoc charge failed. See class-level billing-leak
        # contract documentation.
        await self._try_send_insufficient_funds_notif(
            node_exec.user_id,
            node_exec.graph_id,
            e,
            log_metadata,
        )
    except Exception as e:
        # Non-IBE billing failure (e.g. DB outage): log with full traceback
        # but leave the run COMPLETED and stats untouched.
        log_metadata.error(
            f"billing_leak: failed to charge extra iterations "
            f"for {node.block.name}",
            extra={
                "billing_leak": True,
                "user_id": node_exec.user_id,
                "graph_id": node_exec.graph_id,
                "block_id": node_exec.block_id,
                "extra_iterations": extra_iterations,
                "error_type": type(e).__name__,
                "error": str(e),
            },
            exc_info=True,
        )
@async_time_measured
async def _on_node_execution(
self,
@@ -944,6 +1061,27 @@ class ExecutionProcessor:
stats=exec_stats,
)
def _resolve_block_cost(
    self,
    node_exec: NodeExecutionEntry,
) -> tuple[Block | None, int, dict]:
    """Resolve the block for *node_exec* and compute its base usage cost.

    Single shared lookup for :meth:`_charge_usage` and
    :meth:`charge_extra_iterations`, keeping the (get_block,
    block_usage_cost) pair in exactly one place. Returns ``(block, cost,
    matching_filter)``; ``block`` is ``None`` when the block id cannot be
    resolved, which callers treat as "nothing to charge".
    """
    resolved = get_block(node_exec.block_id)
    if not resolved:
        logger.error(f"Block {node_exec.block_id} not found.")
        return None, 0, {}
    base_cost, matching_filter = block_usage_cost(
        block=resolved, input_data=node_exec.inputs
    )
    return resolved, base_cost, matching_filter
def _charge_usage(
self,
node_exec: NodeExecutionEntry,
@@ -952,14 +1090,10 @@ class ExecutionProcessor:
total_cost = 0
remaining_balance = 0
db_client = get_db_client()
block = get_block(node_exec.block_id)
block, cost, matching_filter = self._resolve_block_cost(node_exec)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return total_cost, 0
cost, matching_filter = block_usage_cost(
block=block, input_data=node_exec.inputs
)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
@@ -977,7 +1111,13 @@ class ExecutionProcessor:
)
total_cost += cost
cost, usage_count = execution_usage_cost(execution_count)
# execution_count=0 is used by charge_node_usage for nested tool calls
# which must not be pushed into higher execution-count tiers.
# execution_usage_cost(0) would trigger a charge because 0 % threshold == 0,
# so skip it entirely when execution_count is 0.
cost, usage_count = (
execution_usage_cost(execution_count) if execution_count > 0 else (0, 0)
)
if cost > 0:
remaining_balance = db_client.spend_credits(
user_id=node_exec.user_id,
@@ -996,6 +1136,116 @@ class ExecutionProcessor:
return total_cost, remaining_balance
# Hard cap on the multiplier passed to charge_extra_iterations to
# protect against a corrupted llm_call_count draining a user's balance.
# Real agent-mode runs are bounded by agent_mode_max_iterations (~50);
# 200 leaves headroom while preventing runaway charges.
_MAX_EXTRA_ITERATIONS = 200
def _charge_extra_iterations_sync(
    self,
    node_exec: NodeExecutionEntry,
    capped_iterations: int,
) -> tuple[int, int]:
    """Thread-pool worker body for :meth:`charge_extra_iterations`.

    Must not be called directly from async code.

    ``_resolve_block_cost`` is re-run here rather than reusing the result
    from ``_charge_usage`` at execution start: the two calls execute in
    separate thread-pool workers, and sharing mutable state across workers
    would need locking. Block config is immutable during a run, so the
    repeated lookup is safe and yields the same cost — the only overhead is
    one extra registry lookup.
    """
    db = get_db_client()
    block, unit_cost, matching_filter = self._resolve_block_cost(node_exec)
    if not block or unit_cost <= 0:
        return 0, 0

    billed = unit_cost * capped_iterations
    txn_metadata = UsageTransactionMetadata(
        graph_exec_id=node_exec.graph_exec_id,
        graph_id=node_exec.graph_id,
        node_exec_id=node_exec.node_exec_id,
        node_id=node_exec.node_id,
        block_id=node_exec.block_id,
        block=block.name,
        input={
            **matching_filter,
            "extra_iterations": capped_iterations,
        },
        reason=(
            f"Extra agent-mode iterations for {block.name} "
            f"({capped_iterations} additional LLM calls)"
        ),
    )
    balance_after = db.spend_credits(
        user_id=node_exec.user_id,
        cost=billed,
        metadata=txn_metadata,
    )
    return billed, balance_after
async def charge_extra_iterations(
    self,
    node_exec: NodeExecutionEntry,
    extra_iterations: int,
) -> tuple[int, int]:
    """Bill *extra_iterations* extra copies of the block's base cost.

    Agent-mode blocks (e.g. OrchestratorBlock) may issue several LLM calls
    within one node execution; :meth:`_charge_usage` already covered the
    first iteration, and this method charges the remainder.

    Returns ``(total_extra_cost, remaining_balance)``. May raise
    ``InsufficientBalanceError`` when the user cannot cover the charge.
    """
    if extra_iterations <= 0:
        return 0, 0
    # Bound the multiplier so a corrupted llm_call_count cannot drain
    # the user's balance.
    bounded = min(extra_iterations, self._MAX_EXTRA_ITERATIONS)
    return await asyncio.to_thread(
        self._charge_extra_iterations_sync, node_exec, bounded
    )
def _charge_and_check_balance(
    self,
    node_exec: NodeExecutionEntry,
) -> tuple[int, int]:
    """Charge one node execution and run the low-balance check together.

    Bundles ``_charge_usage`` with ``_handle_low_balance`` so each tool
    execution dispatches a single thread-pool job instead of two.
    """
    cost, balance = self._charge_usage(node_exec, 0)
    if cost <= 0:
        return cost, balance
    self._handle_low_balance(get_db_client(), node_exec.user_id, balance, cost)
    return cost, balance
async def charge_node_usage(
    self,
    node_exec: NodeExecutionEntry,
) -> tuple[int, int]:
    """Async entry point for explicitly billing one nested node execution.

    Blocks that spawn node executions outside the main queue (e.g. the
    OrchestratorBlock's tool calls) use this rather than touching the
    private ``_charge_usage`` directly; the low-balance notification is
    handled here as well.

    Deliberately does **not** bump the global execution counter
    (``increment_execution_count``): nested tool runs are sub-steps of a
    single block execution from the user's point of view and must not push
    them into higher per-execution cost tiers.
    """
    return await asyncio.to_thread(self._charge_and_check_balance, node_exec)
@time_measured
def _on_graph_execution(
self,