fix(orchestrator): propagate billing errors and close leak windows

Round 3 review fixes on top of the per-iteration cost charging PR:

- Propagate InsufficientBalanceError out of `_agent_mode_tool_executor`
  and the SDK MCP tool handler so billing failures stop the agent loop
  instead of being re-injected into the LLM as a tool error (which
  previously leaked balance amounts and let the loop keep consuming
  unpaid compute). See the sketch after this list.
- On post-execution extra-iteration charging failure, record
  `execution_stats.error`, log with structured billing_leak fields,
  and fire `_handle_insufficient_funds_notif` so the user is actually
  notified. Comment now matches behaviour.
- Tighten tool-success gate to `tool_node_stats.error is None` so
  cancelled/terminated tool runs (BaseException subclasses such as
  CancelledError) are not billed.
- Extract shared `_resolve_block_cost` helper used by `_charge_usage`
  and `charge_extra_iterations` to DRY the block/cost lookup.
- Add integration tests for the `on_node_execution` charging gate
  covering each branch (status, flag, llm_call_count, dry_run) plus
  the InsufficientBalanceError path that asserts error recording and
  notification.
- Add tool-charging skip tests (dry_run, failed tool, cancelled tool)
  and an InsufficientBalanceError propagation test for
  `_execute_single_tool_with_manager`.
- Assert `charge_node_usage` is actually called in the existing
  `test_orchestrator_agent_mode` test and return a non-zero cost so
  the `merge_stats` branch is exercised.
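
The propagation fix, reduced to a sketch (illustrative names only, not
the actual executor code):

    class InsufficientBalanceError(Exception):
        """Stand-in for backend.util.exceptions.InsufficientBalanceError."""

    async def run_tool(execute):
        try:
            return await execute()
        except InsufficientBalanceError:
            # Halt the loop: turning this into an in-band tool error
            # would leak balance details into the LLM context and let
            # the loop keep consuming unpaid compute.
            raise
        except Exception as e:
            # Ordinary tool failures stay in-band as tool errors.
            return {"is_error": True, "content": str(e)}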

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Author: majdyz
Date:   2026-04-10 11:53:46 +00:00
Parent: 215340690f
Commit: ada2725628

4 changed files with 532 additions and 22 deletions

backend/blocks/orchestrator.py

@@ -1130,10 +1130,14 @@ class OrchestratorBlock(Block):
         # avoid free tool execution. Charging post-completion (vs.
         # pre-execution) avoids billing users for failed tool calls.
         # Skipped for dry runs.
+        #
+        # `error is None` intentionally excludes both Exception and
+        # BaseException subclasses (e.g. CancelledError) so cancelled
+        # or terminated tool runs are not billed.
         if (
             not execution_params.execution_context.dry_run
             and tool_node_stats is not None
-            and not isinstance(tool_node_stats.error, Exception)
+            and tool_node_stats.error is None
         ):
             tool_cost, _ = await asyncio.to_thread(
                 execution_processor.charge_node_usage,
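
Why the gate had to change: `asyncio.CancelledError` derives from
`BaseException` rather than `Exception` (since Python 3.8), so the old
`isinstance` check classified a cancelled tool run as error-free and
billed it. A standalone check in plain Python, no project code needed:

    import asyncio

    err = asyncio.CancelledError()
    # Old gate: "no Exception recorded" looked like success, so it billed.
    assert not isinstance(err, Exception)
    assert isinstance(err, BaseException)
    # New gate: any recorded error (the `is None` check) skips charging.
    assert err is not None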
@@ -1277,6 +1281,13 @@ class OrchestratorBlock(Block):
                 content=content,
                 is_error=tool_failed,
             )
+        except InsufficientBalanceError:
+            # Billing failures must stop the agent loop cleanly — do NOT
+            # downgrade them into a tool error that gets fed back to the
+            # LLM. Re-raise so the orchestrator's outer error handling
+            # halts the run (mirrors main execution queue behaviour) and
+            # avoids leaking exact balance amounts into LLM context.
+            raise
         except Exception as e:
             logger.error("Tool execution failed: %s", e)
             return ToolCallResult(
@@ -1481,6 +1492,13 @@ class OrchestratorBlock(Block):
                     "content": [{"type": "text", "text": text}],
                     "isError": tool_failed,
                 }
+            except InsufficientBalanceError:
+                # Same carve-out as _agent_mode_tool_executor:
+                # billing failures must propagate to stop the run
+                # rather than be fed back to the LLM as a tool
+                # error (which would leak balance amounts and let
+                # the loop continue consuming unbillable work).
+                raise
             except Exception as e:
                 logger.error("SDK tool execution failed: %s", e)
                 return {
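
Both carve-outs rely on `except` clauses matching in order. Since the
generic handler used to swallow these billing failures,
`InsufficientBalanceError` must subclass `Exception`, so its clause has
to sit above the generic one to win. A standalone illustration with a
stand-in class:

    class InsufficientBalanceError(Exception):  # stand-in for the real type
        pass

    def handle():
        try:
            raise InsufficientBalanceError("out of credits")
        except InsufficientBalanceError:
            raise  # matches first; the generic clause below never runs
        except Exception:
            return "tool error re-injected into the LLM"  # pre-fix behaviour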


@@ -924,8 +924,11 @@ async def test_orchestrator_agent_mode():
     )
     # Mock charge_node_usage (called after successful tool execution).
     # Returns (cost, remaining_balance). Synchronous because it's called
-    # via asyncio.to_thread.
-    mock_execution_processor.charge_node_usage = MagicMock(return_value=(0, 0))
+    # via asyncio.to_thread. Use a non-zero cost so the merge_stats
+    # branch is actually exercised, and assert it's called below.
+    mock_execution_processor.charge_node_usage = MagicMock(
+        return_value=(10, 990)
+    )
     # Mock the get_execution_outputs_by_node_exec_id method
     mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
@@ -971,6 +974,11 @@ async def test_orchestrator_agent_mode():
     # Verify tool was executed via execution processor
     assert mock_execution_processor.on_node_execution.call_count == 1
+    # Verify charge_node_usage was actually called for the successful
+    # tool execution — this guards against regressions where the
+    # post-execution tool charging is accidentally removed.
+    assert mock_execution_processor.charge_node_usage.call_count == 1
 
 
 @pytest.mark.asyncio
 async def test_orchestrator_traditional_mode_default():
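
The switch to `(10, 990)` matters because a gate of the form
`if tool_cost > 0` is dead under a `(0, 0)` stub, so the stats-merging
path was never executed. A reduced sketch of the difference (branch
shape assumed from the commit message, not copied from the orchestrator):

    from unittest.mock import MagicMock

    merged = []
    charge = MagicMock(return_value=(10, 990))  # (cost, remaining_balance)

    tool_cost, _ = charge()
    if tool_cost > 0:             # never taken when the mock returns (0, 0)
        merged.append(tool_cost)  # stands in for the merge_stats call

    assert charge.call_count == 1 and merged == [10]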


@@ -6,7 +6,8 @@ this and charge ``base_cost * (llm_call_count - 1)`` extra credits after
 the block completes.
 """
 
-from unittest.mock import MagicMock
+import threading
+from unittest.mock import AsyncMock, MagicMock
 
 import pytest
 
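
The docstring's formula in concrete numbers (base cost arbitrary): a
block that made 3 LLM calls already paid one base charge up front, so
only the 2 extra iterations are billed afterwards:

    base_cost = 10
    llm_call_count = 3
    extra_iterations = llm_call_count - 1      # = 2
    assert base_cost * extra_iterations == 20  # extra credits charged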
@@ -236,3 +237,441 @@ class TestChargeNodeUsage:
         assert cost == 5
         assert balance == 100
         assert captured["execution_count"] == 0
+
+
+# ── on_node_execution charging gate ────────────────────────────────
+
+
+class _FakeNode:
+    """Minimal stand-in for a ``Node`` object with a block attribute."""
+
+    def __init__(self, charge_per_llm_call: bool, block_name: str = "FakeBlock"):
+        self.block = MagicMock()
+        self.block.charge_per_llm_call = charge_per_llm_call
+        self.block.name = block_name
+
+
+class _FakeExecContext:
+    def __init__(self, dry_run: bool = False):
+        self.dry_run = dry_run
+
+
+def _make_node_exec(dry_run: bool = False) -> MagicMock:
+    """Build a NodeExecutionEntry-like mock for on_node_execution tests."""
+    ne = MagicMock()
+    ne.user_id = "u"
+    ne.graph_id = "g"
+    ne.graph_exec_id = "ge"
+    ne.node_id = "n"
+    ne.node_exec_id = "ne"
+    ne.block_id = "b"
+    ne.inputs = {}
+    ne.execution_context = _FakeExecContext(dry_run=dry_run)
+    return ne
+
+
+@pytest.fixture()
+def gated_processor(monkeypatch):
+    """ExecutionProcessor with on_node_execution's downstream calls stubbed.
+
+    Lets tests flip the four gate conditions (status, charge_per_llm_call,
+    llm_call_count, dry_run) and observe whether charge_extra_iterations
+    was called.
+    """
+    from backend.data.execution import ExecutionStatus
+    from backend.data.model import NodeExecutionStats
+    from backend.executor import manager
+
+    calls: dict[str, list] = {
+        "charge_extra_iterations": [],
+        "handle_low_balance": [],
+        "handle_insufficient_funds_notif": [],
+    }
+
+    # Stub node lookup + DB client so the wrapper doesn't touch real infra.
+    fake_db = MagicMock()
+    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))
+    monkeypatch.setattr(manager, "get_db_async_client", lambda: fake_db)
+    monkeypatch.setattr(manager, "get_db_client", lambda: fake_db)
+    # get_block is called by LogMetadata construction in on_node_execution.
+    monkeypatch.setattr(
+        manager,
+        "get_block",
+        lambda block_id: MagicMock(name="FakeBlock"),
+    )
+    # Persistence + cost logging are not under test here.
+    monkeypatch.setattr(
+        manager,
+        "async_update_node_execution_status",
+        AsyncMock(return_value=None),
+    )
+    monkeypatch.setattr(
+        manager,
+        "async_update_graph_execution_state",
+        AsyncMock(return_value=None),
+    )
+    monkeypatch.setattr(
+        manager,
+        "log_system_credential_cost",
+        AsyncMock(return_value=None),
+    )
+
+    proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
+
+    # Control the status returned by the inner execution call.
+    inner_result = {"status": ExecutionStatus.COMPLETED, "llm_call_count": 3}
+
+    async def fake_inner(
+        self, *, node, node_exec, node_exec_progress, stats, db_client,
+        log_metadata, nodes_input_masks=None, nodes_to_skip=None,
+    ):
+        stats.llm_call_count = inner_result["llm_call_count"]
+        return MagicMock(wall_time=0.1, cpu_time=0.1), inner_result["status"]
+
+    monkeypatch.setattr(
+        manager.ExecutionProcessor,
+        "_on_node_execution",
+        fake_inner,
+    )
+
+    def fake_charge_extra(self, node_exec, extra_iterations):
+        calls["charge_extra_iterations"].append(extra_iterations)
+        return (extra_iterations * 10, 500)
+
+    monkeypatch.setattr(
+        manager.ExecutionProcessor,
+        "charge_extra_iterations",
+        fake_charge_extra,
+    )
+
+    def fake_low_balance(self, **kwargs):
+        calls["handle_low_balance"].append(kwargs)
+
+    monkeypatch.setattr(
+        manager.ExecutionProcessor,
+        "_handle_low_balance",
+        fake_low_balance,
+    )
+
+    def fake_notif(self, db_client, user_id, graph_id, e):
+        calls["handle_insufficient_funds_notif"].append(
+            {"user_id": user_id, "graph_id": graph_id, "error": e}
+        )
+
+    monkeypatch.setattr(
+        manager.ExecutionProcessor,
+        "_handle_insufficient_funds_notif",
+        fake_notif,
+    )
+
+    return proc, calls, inner_result, fake_db, NodeExecutionStats
+
+
+@pytest.mark.asyncio
+async def test_on_node_execution_charges_extra_iterations_when_gate_passes(
+    gated_processor,
+):
+    """COMPLETED + charge_per_llm_call + llm_call_count>1 + not dry_run → charged."""
+    from backend.data.execution import ExecutionStatus
+
+    proc, calls, inner, fake_db, _ = gated_processor
+    inner["status"] = ExecutionStatus.COMPLETED
+    inner["llm_call_count"] = 3  # → extra_iterations = 2
+    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))
+
+    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
+                            cost=0, node_error_count=0), threading.Lock())
+    await proc.on_node_execution(
+        node_exec=_make_node_exec(dry_run=False),
+        node_exec_progress=MagicMock(),
+        nodes_input_masks=None,
+        graph_stats_pair=stats_pair,
+    )
+
+    assert calls["charge_extra_iterations"] == [2]
+
+
+@pytest.mark.asyncio
+async def test_on_node_execution_skips_when_status_not_completed(gated_processor):
+    from backend.data.execution import ExecutionStatus
+
+    proc, calls, inner, fake_db, _ = gated_processor
+    inner["status"] = ExecutionStatus.FAILED
+    inner["llm_call_count"] = 5
+    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))
+
+    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
+                            cost=0, node_error_count=0), threading.Lock())
+    await proc.on_node_execution(
+        node_exec=_make_node_exec(dry_run=False),
+        node_exec_progress=MagicMock(),
+        nodes_input_masks=None,
+        graph_stats_pair=stats_pair,
+    )
+
+    assert calls["charge_extra_iterations"] == []
+
+
+@pytest.mark.asyncio
+async def test_on_node_execution_skips_when_charge_flag_false(gated_processor):
+    from backend.data.execution import ExecutionStatus
+
+    proc, calls, inner, fake_db, _ = gated_processor
+    inner["status"] = ExecutionStatus.COMPLETED
+    inner["llm_call_count"] = 5
+    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=False))
+
+    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
+                            cost=0, node_error_count=0), threading.Lock())
+    await proc.on_node_execution(
+        node_exec=_make_node_exec(dry_run=False),
+        node_exec_progress=MagicMock(),
+        nodes_input_masks=None,
+        graph_stats_pair=stats_pair,
+    )
+
+    assert calls["charge_extra_iterations"] == []
+
+
+@pytest.mark.asyncio
+async def test_on_node_execution_skips_when_llm_call_count_le_1(gated_processor):
+    from backend.data.execution import ExecutionStatus
+
+    proc, calls, inner, fake_db, _ = gated_processor
+    inner["status"] = ExecutionStatus.COMPLETED
+    inner["llm_call_count"] = 1  # exactly the base charge, no extras
+    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))
+
+    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
+                            cost=0, node_error_count=0), threading.Lock())
+    await proc.on_node_execution(
+        node_exec=_make_node_exec(dry_run=False),
+        node_exec_progress=MagicMock(),
+        nodes_input_masks=None,
+        graph_stats_pair=stats_pair,
+    )
+
+    assert calls["charge_extra_iterations"] == []
+
+
+@pytest.mark.asyncio
+async def test_on_node_execution_skips_when_dry_run(gated_processor):
+    from backend.data.execution import ExecutionStatus
+
+    proc, calls, inner, fake_db, _ = gated_processor
+    inner["status"] = ExecutionStatus.COMPLETED
+    inner["llm_call_count"] = 5
+    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))
+
+    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
+                            cost=0, node_error_count=0), threading.Lock())
+    await proc.on_node_execution(
+        node_exec=_make_node_exec(dry_run=True),
+        node_exec_progress=MagicMock(),
+        nodes_input_masks=None,
+        graph_stats_pair=stats_pair,
+    )
+
+    assert calls["charge_extra_iterations"] == []
+
+
+@pytest.mark.asyncio
+async def test_on_node_execution_insufficient_balance_records_error_and_notifies(
+    monkeypatch, gated_processor,
+):
+    """When extra-iteration charging fails with InsufficientBalanceError:
+
+    - the run still reports COMPLETED (the work is already done)
+    - execution_stats.error is set so monitoring picks it up
+    - _handle_insufficient_funds_notif is called so the user is notified
+    """
+    from backend.data.execution import ExecutionStatus
+    from backend.executor import manager
+    from backend.util.exceptions import InsufficientBalanceError
+
+    proc, calls, inner, fake_db, _ = gated_processor
+    inner["status"] = ExecutionStatus.COMPLETED
+    inner["llm_call_count"] = 4
+    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))
+
+    def raise_ibe(self, node_exec, extra_iterations):
+        raise InsufficientBalanceError(
+            user_id=node_exec.user_id,
+            message="Insufficient balance",
+            balance=0,
+            amount=extra_iterations * 10,
+        )
+
+    monkeypatch.setattr(
+        manager.ExecutionProcessor, "charge_extra_iterations", raise_ibe
+    )
+
+    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
+                            cost=0, node_error_count=0), threading.Lock())
+    result_stats = await proc.on_node_execution(
+        node_exec=_make_node_exec(dry_run=False),
+        node_exec_progress=MagicMock(),
+        nodes_input_masks=None,
+        graph_stats_pair=stats_pair,
+    )
+
+    # Error recorded on stats so downstream monitoring can surface it.
+    assert isinstance(result_stats.error, InsufficientBalanceError)
+    # User notification fired.
+    assert len(calls["handle_insufficient_funds_notif"]) == 1
+    assert calls["handle_insufficient_funds_notif"][0]["user_id"] == "u"
+
+
+# ── Orchestrator _execute_single_tool_with_manager charging gates ──
+
+
+async def _run_tool_exec_with_stats(
+    *, dry_run: bool, tool_stats_error, charge_node_usage_mock=None,
+):
+    """Invoke _execute_single_tool_with_manager against fully mocked deps
+    and return the ``charge_node_usage`` mock plus any exception raised.
+
+    Used to prove the dry_run and error guards around charge_node_usage
+    behave as documented, and that InsufficientBalanceError propagates.
+    """
+    import threading
+    from collections import defaultdict
+    from unittest.mock import AsyncMock, MagicMock, patch
+
+    from backend.blocks.orchestrator import ExecutionParams, OrchestratorBlock
+    from backend.data.execution import ExecutionContext
+
+    block = OrchestratorBlock()
+
+    # Mocked async DB client used inside orchestrator.
+    mock_db_client = AsyncMock()
+    mock_target_node = MagicMock()
+    mock_target_node.block_id = "test-block-id"
+    mock_target_node.input_default = {}
+    mock_db_client.get_node.return_value = mock_target_node
+    mock_node_exec_result = MagicMock()
+    mock_node_exec_result.node_exec_id = "test-tool-exec-id"
+    mock_db_client.upsert_execution_input.return_value = (
+        mock_node_exec_result,
+        {"query": "t"},
+    )
+    mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
+        "result": "ok"
+    }
+
+    # ExecutionProcessor mock: on_node_execution returns supplied error.
+    mock_processor = AsyncMock()
+    mock_processor.running_node_execution = defaultdict(MagicMock)
+    mock_processor.execution_stats = MagicMock()
+    mock_processor.execution_stats_lock = threading.Lock()
+    mock_node_stats = MagicMock()
+    mock_node_stats.error = tool_stats_error
+    mock_processor.on_node_execution = AsyncMock(return_value=mock_node_stats)
+    mock_processor.charge_node_usage = (
+        charge_node_usage_mock or MagicMock(return_value=(10, 990))
+    )
+
+    # Build a tool_info shaped like _build_tool_info_from_args output.
+    tool_call = MagicMock()
+    tool_call.id = "call-1"
+    tool_call.name = "search_keywords"
+    tool_call.arguments = '{"query":"t"}'
+    tool_def = {
+        "type": "function",
+        "function": {
+            "name": "search_keywords",
+            "_sink_node_id": "test-sink-node-id",
+            "_field_mapping": {},
+            "parameters": {
+                "properties": {"query": {"type": "string"}},
+                "required": ["query"],
+            },
+        },
+    }
+    tool_info = OrchestratorBlock._build_tool_info_from_args(
+        tool_call_id="call-1",
+        tool_name="search_keywords",
+        tool_args={"query": "t"},
+        tool_def=tool_def,
+    )
+
+    exec_params = ExecutionParams(
+        user_id="u",
+        graph_id="g",
+        node_id="n",
+        graph_version=1,
+        graph_exec_id="ge",
+        node_exec_id="ne",
+        execution_context=ExecutionContext(
+            human_in_the_loop_safe_mode=False, dry_run=dry_run
+        ),
+    )
+
+    with patch(
+        "backend.blocks.orchestrator.get_database_manager_async_client",
+        return_value=mock_db_client,
+    ):
+        try:
+            await block._execute_single_tool_with_manager(
+                tool_info, exec_params, mock_processor, responses_api=False
+            )
+            raised = None
+        except Exception as e:
+            raised = e
+
+    return mock_processor.charge_node_usage, raised
+
+
+@pytest.mark.asyncio
+async def test_tool_execution_skips_charging_on_dry_run():
+    """dry_run=True → charge_node_usage is NOT called."""
+    charge_mock, raised = await _run_tool_exec_with_stats(
+        dry_run=True, tool_stats_error=None
+    )
+    assert raised is None
+    assert charge_mock.call_count == 0
+
+
+@pytest.mark.asyncio
+async def test_tool_execution_skips_charging_on_failed_tool():
+    """tool_node_stats.error is an Exception → charge_node_usage NOT called."""
+    charge_mock, raised = await _run_tool_exec_with_stats(
+        dry_run=False, tool_stats_error=RuntimeError("tool blew up")
+    )
+    assert raised is None
+    assert charge_mock.call_count == 0
+
+
+@pytest.mark.asyncio
+async def test_tool_execution_skips_charging_on_cancelled_tool():
+    """Cancellation (BaseException subclass) → charge_node_usage NOT called.
+
+    Guards the fix for sentry's BaseException concern: the old
+    `isinstance(error, Exception)` check would have treated CancelledError
+    as "no error" and billed the user for a terminated run.
+    """
+    import asyncio as _asyncio
+
+    charge_mock, raised = await _run_tool_exec_with_stats(
+        dry_run=False, tool_stats_error=_asyncio.CancelledError()
+    )
+    assert raised is None
+    assert charge_mock.call_count == 0
+
+
+@pytest.mark.asyncio
+async def test_tool_execution_insufficient_balance_propagates():
+    """InsufficientBalanceError from charge_node_usage must propagate out.
+
+    If this leaked into a ToolCallResult the LLM loop would keep running
+    with 'tool failed' errors and the user would get unpaid work.
+    """
+    from unittest.mock import MagicMock
+
+    from backend.util.exceptions import InsufficientBalanceError
+
+    raising_charge = MagicMock(
+        side_effect=InsufficientBalanceError(
+            user_id="u", message="nope", balance=0, amount=10
+        )
+    )
+    _, raised = await _run_tool_exec_with_stats(
+        dry_run=False,
+        tool_stats_error=None,
+        charge_node_usage_mock=raising_charge,
+    )
+    assert isinstance(raised, InsufficientBalanceError)

backend/executor/manager.py

@@ -684,10 +684,16 @@ class ExecutionProcessor:
         # is already covered by _charge_usage(); each additional LLM call
         # costs another base_cost. Skipped for dry runs and failed runs.
         #
-        # InsufficientBalanceError is logged at ERROR level (this is a
-        # billing leak — the work is already done, but the user can't pay)
-        # and re-surfaced via execution_stats.error so monitoring can pick
-        # it up. Other exceptions are warnings.
+        # InsufficientBalanceError here is a post-hoc billing leak — the
+        # work is already done but the user can no longer pay. We:
+        #   1. log at ERROR with structured fields so alerting can catch it
+        #   2. record the error on execution_stats.error for downstream
+        #      monitoring (stats are persisted into node_stats below)
+        #   3. fire _handle_insufficient_funds_notif so the user is
+        #      notified (mirrors the main queue path at ~line 1254)
+        # The run itself is kept COMPLETED (the block's outputs are
+        # already committed) — matching the documented "billing leak"
+        # contract rather than retroactively failing a successful run.
         if (
             status == ExecutionStatus.COMPLETED
             and node.block.charge_per_llm_call
@@ -711,9 +717,36 @@ class ExecutionProcessor:
                 )
             except InsufficientBalanceError as e:
                 log_metadata.error(
-                    f"Billing leak: insufficient balance after {node.block.name} "
-                    f"completed {extra_iterations} extra iterations: {e}"
+                    "billing_leak: insufficient balance after "
+                    f"{node.block.name} completed {extra_iterations} "
+                    f"extra iterations",
+                    extra={
+                        "billing_leak": True,
+                        "user_id": node_exec.user_id,
+                        "graph_id": node_exec.graph_id,
+                        "block_id": node_exec.block_id,
+                        "extra_iterations": extra_iterations,
+                        "error": str(e),
+                    },
                 )
+                # Surface on execution_stats so node_stats persistence
+                # below records the billing failure for monitoring.
+                execution_stats.error = e
+                # Notify the user they're out of credits. Runs through
+                # Redis dedup (per user+graph) so repeat runs don't spam.
+                try:
+                    await asyncio.to_thread(
+                        self._handle_insufficient_funds_notif,
+                        get_db_client(),
+                        node_exec.user_id,
+                        node_exec.graph_id,
+                        e,
+                    )
+                except Exception as notif_error:  # pragma: no cover
+                    log_metadata.warning(
+                        f"Failed to send insufficient funds notification: "
+                        f"{notif_error}"
+                    )
             except Exception as e:
                 log_metadata.warning(
                     f"Failed to charge extra iterations for {node.block.name}: {e}"
@@ -982,6 +1015,27 @@ class ExecutionProcessor:
             stats=exec_stats,
         )
 
+    def _resolve_block_cost(
+        self,
+        node_exec: NodeExecutionEntry,
+    ) -> tuple[Any, int, dict]:
+        """Look up the block and compute its base usage cost for an exec.
+
+        Shared by :meth:`_charge_usage` and :meth:`charge_extra_iterations`
+        so the (get_block, block_usage_cost) lookup lives in exactly one
+        place. Returns ``(block, cost, matching_filter)``. ``block`` is
+        ``None`` if the block id can't be resolved — callers should treat
+        that as "nothing to charge".
+        """
+        block = get_block(node_exec.block_id)
+        if not block:
+            logger.error(f"Block {node_exec.block_id} not found.")
+            return None, 0, {}
+        cost, matching_filter = block_usage_cost(
+            block=block, input_data=node_exec.inputs
+        )
+        return block, cost, matching_filter
+
     def _charge_usage(
         self,
         node_exec: NodeExecutionEntry,
@@ -990,14 +1044,10 @@ class ExecutionProcessor:
         total_cost = 0
         remaining_balance = 0
         db_client = get_db_client()
-        block = get_block(node_exec.block_id)
+        block, cost, matching_filter = self._resolve_block_cost(node_exec)
         if not block:
-            logger.error(f"Block {node_exec.block_id} not found.")
             return total_cost, 0
-        cost, matching_filter = block_usage_cost(
-            block=block, input_data=node_exec.inputs
-        )
 
         if cost > 0:
             remaining_balance = db_client.spend_credits(
                 user_id=node_exec.user_id,
@@ -1081,13 +1131,8 @@ class ExecutionProcessor:
         # Cap to protect against a corrupted llm_call_count.
         capped_iterations = min(extra_iterations, self._MAX_EXTRA_ITERATIONS)
         db_client = get_db_client()
-        block = get_block(node_exec.block_id)
-        if not block:
-            return 0, 0
-        cost, matching_filter = block_usage_cost(
-            block=block, input_data=node_exec.inputs
-        )
-        if cost <= 0:
+        block, cost, matching_filter = self._resolve_block_cost(node_exec)
+        if not block or cost <= 0:
             return 0, 0
         total_extra_cost = cost * capped_iterations
         remaining_balance = db_client.spend_credits(