Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-04-30 03:00:41 -04:00
fix(orchestrator): propagate billing errors and close leak windows
Round 3 review fixes on top of the per-iteration cost charging PR:

- Propagate InsufficientBalanceError out of `_agent_mode_tool_executor` and the SDK MCP tool handler so billing failures stop the agent loop instead of being re-injected into the LLM as a tool error (which previously leaked balance amounts and let the loop keep consuming unpaid compute).
- On post-execution extra-iteration charging failure, record `execution_stats.error`, log with structured billing_leak fields, and fire `_handle_insufficient_funds_notif` so the user is actually notified. The comment now matches the behaviour.
- Tighten the tool-success gate to `tool_node_stats.error is None` so cancelled/terminated tool runs (BaseException subclasses such as CancelledError) are not billed.
- Extract a shared `_resolve_block_cost` helper used by `_charge_usage` and `charge_extra_iterations` to DRY the block/cost lookup.
- Add integration tests for the `on_node_execution` charging gate covering each branch (status, flag, llm_call_count, dry_run), plus the InsufficientBalanceError path that asserts error recording and notification.
- Add tool-charging skip tests (dry_run, failed tool, cancelled tool) and an InsufficientBalanceError propagation test for `_execute_single_tool_with_manager`.
- Assert `charge_node_usage` is actually called in the existing `test_orchestrator_agent_mode` test and return a non-zero cost so the `merge_stats` branch is exercised.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
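The core of the gate change is a one-line semantic difference that is easy to miss in the diff below. A minimal standalone sketch (plain functions, not the repository's code; `error` stands in for `tool_node_stats.error`) of why the old `isinstance` check billed cancelled runs:

```python
import asyncio

def should_bill_old(error) -> bool:
    # Old gate: only skips billing when the error is an Exception subclass.
    # asyncio.CancelledError derives from BaseException (Python >= 3.8),
    # so a cancelled tool run still passed this check and got billed.
    return not isinstance(error, Exception)

def should_bill_new(error) -> bool:
    # New gate: any recorded error, Exception or BaseException, skips billing.
    return error is None

cancelled = asyncio.CancelledError()
assert should_bill_old(cancelled) is True    # old behaviour: billed anyway
assert should_bill_new(cancelled) is False   # fixed: not billed
assert should_bill_new(None) is True         # successful run is still billed
```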
@@ -1130,10 +1130,14 @@ class OrchestratorBlock(Block):
             # avoid free tool execution. Charging post-completion (vs.
             # pre-execution) avoids billing users for failed tool calls.
             # Skipped for dry runs.
+            #
+            # `error is None` intentionally excludes both Exception and
+            # BaseException subclasses (e.g. CancelledError) so cancelled
+            # or terminated tool runs are not billed.
             if (
                 not execution_params.execution_context.dry_run
                 and tool_node_stats is not None
-                and not isinstance(tool_node_stats.error, Exception)
+                and tool_node_stats.error is None
             ):
                 tool_cost, _ = await asyncio.to_thread(
                     execution_processor.charge_node_usage,

@@ -1277,6 +1281,13 @@ class OrchestratorBlock(Block):
                 content=content,
                 is_error=tool_failed,
             )
+        except InsufficientBalanceError:
+            # Billing failures must stop the agent loop cleanly — do NOT
+            # downgrade them into a tool error that gets fed back to the
+            # LLM. Re-raise so the orchestrator's outer error handling
+            # halts the run (mirrors main execution queue behaviour) and
+            # avoids leaking exact balance amounts into LLM context.
+            raise
         except Exception as e:
             logger.error("Tool execution failed: %s", e)
             return ToolCallResult(

@@ -1481,6 +1492,13 @@ class OrchestratorBlock(Block):
                 "content": [{"type": "text", "text": text}],
                 "isError": tool_failed,
             }
+        except InsufficientBalanceError:
+            # Same carve-out as _agent_mode_tool_executor:
+            # billing failures must propagate to stop the run
+            # rather than be fed back to the LLM as a tool
+            # error (which would leak balance amounts and let
+            # the loop continue consuming unbillable work).
+            raise
         except Exception as e:
            logger.error("SDK tool execution failed: %s", e)
            return {
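The re-raise only keeps billing failures out of the LLM context because the `except InsufficientBalanceError` clause sits before the broad `except Exception` handler. A toy sketch of that ordering (the exception class here is a local stand-in for `backend.util.exceptions.InsufficientBalanceError`, not its real definition):

```python
class InsufficientBalanceError(Exception):
    """Local stand-in for the real billing exception."""

def run_tool(charge):
    try:
        charge()
    except InsufficientBalanceError:
        # Must come before the broad handler, otherwise the billing
        # failure is downgraded into a "tool error" result below.
        raise
    except Exception as e:
        return {"is_error": True, "detail": str(e)}
    return {"is_error": False}

def broke():
    raise InsufficientBalanceError("out of credits")

try:
    run_tool(broke)
except InsufficientBalanceError:
    print("billing failure propagated; the agent loop halts here")
```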
@@ -924,8 +924,11 @@ async def test_orchestrator_agent_mode():
     )
     # Mock charge_node_usage (called after successful tool execution).
     # Returns (cost, remaining_balance). Synchronous because it's called
-    # via asyncio.to_thread.
-    mock_execution_processor.charge_node_usage = MagicMock(return_value=(0, 0))
+    # via asyncio.to_thread. Use a non-zero cost so the merge_stats
+    # branch is actually exercised, and assert it's called below.
+    mock_execution_processor.charge_node_usage = MagicMock(
+        return_value=(10, 990)
+    )

     # Mock the get_execution_outputs_by_node_exec_id method
     mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {

@@ -971,6 +974,11 @@ async def test_orchestrator_agent_mode():
     # Verify tool was executed via execution processor
     assert mock_execution_processor.on_node_execution.call_count == 1

+    # Verify charge_node_usage was actually called for the successful
+    # tool execution — this guards against regressions where the
+    # post-execution tool charging is accidentally removed.
+    assert mock_execution_processor.charge_node_usage.call_count == 1
+

 @pytest.mark.asyncio
 async def test_orchestrator_traditional_mode_default():
@@ -6,7 +6,8 @@ this and charge ``base_cost * (llm_call_count - 1)`` extra credits after
 the block completes.
 """

-from unittest.mock import MagicMock
+import threading
+from unittest.mock import AsyncMock, MagicMock

 import pytest
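The module docstring above states the charging model these tests exercise. A small worked example of that formula (the cap default here is illustrative; the real limit is whatever `_MAX_EXTRA_ITERATIONS` is configured to in the executor):

```python
def extra_iteration_charge(base_cost: int, llm_call_count: int, max_extra: int = 100) -> int:
    # The first LLM call is covered by the normal per-execution charge;
    # only the extras are billed, capped to guard against a corrupted count.
    extra_iterations = max(llm_call_count - 1, 0)
    return base_cost * min(extra_iterations, max_extra)

assert extra_iteration_charge(base_cost=10, llm_call_count=3) == 20  # 2 extras, as in the gate test below
assert extra_iteration_charge(base_cost=10, llm_call_count=1) == 0   # no extras, nothing charged
```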
@@ -236,3 +237,441 @@ class TestChargeNodeUsage:
        assert cost == 5
        assert balance == 100
        assert captured["execution_count"] == 0


# ── on_node_execution charging gate ────────────────────────────────


class _FakeNode:
    """Minimal stand-in for a ``Node`` object with a block attribute."""

    def __init__(self, charge_per_llm_call: bool, block_name: str = "FakeBlock"):
        self.block = MagicMock()
        self.block.charge_per_llm_call = charge_per_llm_call
        self.block.name = block_name


class _FakeExecContext:
    def __init__(self, dry_run: bool = False):
        self.dry_run = dry_run


def _make_node_exec(dry_run: bool = False) -> MagicMock:
    """Build a NodeExecutionEntry-like mock for on_node_execution tests."""
    ne = MagicMock()
    ne.user_id = "u"
    ne.graph_id = "g"
    ne.graph_exec_id = "ge"
    ne.node_id = "n"
    ne.node_exec_id = "ne"
    ne.block_id = "b"
    ne.inputs = {}
    ne.execution_context = _FakeExecContext(dry_run=dry_run)
    return ne


@pytest.fixture()
def gated_processor(monkeypatch):
    """ExecutionProcessor with on_node_execution's downstream calls stubbed.

    Lets tests flip the four gate conditions (status, charge_per_llm_call,
    llm_call_count, dry_run) and observe whether charge_extra_iterations
    was called.
    """
    from backend.data.execution import ExecutionStatus
    from backend.data.model import NodeExecutionStats
    from backend.executor import manager

    calls: dict[str, list] = {
        "charge_extra_iterations": [],
        "handle_low_balance": [],
        "handle_insufficient_funds_notif": [],
    }

    # Stub node lookup + DB client so the wrapper doesn't touch real infra.
    fake_db = MagicMock()
    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))
    monkeypatch.setattr(manager, "get_db_async_client", lambda: fake_db)
    monkeypatch.setattr(manager, "get_db_client", lambda: fake_db)
    # get_block is called by LogMetadata construction in on_node_execution.
    monkeypatch.setattr(
        manager,
        "get_block",
        lambda block_id: MagicMock(name="FakeBlock"),
    )
    # Persistence + cost logging are not under test here.
    monkeypatch.setattr(
        manager,
        "async_update_node_execution_status",
        AsyncMock(return_value=None),
    )
    monkeypatch.setattr(
        manager,
        "async_update_graph_execution_state",
        AsyncMock(return_value=None),
    )
    monkeypatch.setattr(
        manager,
        "log_system_credential_cost",
        AsyncMock(return_value=None),
    )

    proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)

    # Control the status returned by the inner execution call.
    inner_result = {"status": ExecutionStatus.COMPLETED, "llm_call_count": 3}

    async def fake_inner(
        self, *, node, node_exec, node_exec_progress, stats, db_client,
        log_metadata, nodes_input_masks=None, nodes_to_skip=None,
    ):
        stats.llm_call_count = inner_result["llm_call_count"]
        return MagicMock(wall_time=0.1, cpu_time=0.1), inner_result["status"]

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_on_node_execution",
        fake_inner,
    )

    def fake_charge_extra(self, node_exec, extra_iterations):
        calls["charge_extra_iterations"].append(extra_iterations)
        return (extra_iterations * 10, 500)

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "charge_extra_iterations",
        fake_charge_extra,
    )

    def fake_low_balance(self, **kwargs):
        calls["handle_low_balance"].append(kwargs)

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_handle_low_balance",
        fake_low_balance,
    )

    def fake_notif(self, db_client, user_id, graph_id, e):
        calls["handle_insufficient_funds_notif"].append(
            {"user_id": user_id, "graph_id": graph_id, "error": e}
        )

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_handle_insufficient_funds_notif",
        fake_notif,
    )

    return proc, calls, inner_result, fake_db, NodeExecutionStats


@pytest.mark.asyncio
async def test_on_node_execution_charges_extra_iterations_when_gate_passes(
    gated_processor,
):
    """COMPLETED + charge_per_llm_call + llm_call_count>1 + not dry_run → charged."""
    from backend.data.execution import ExecutionStatus

    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 3  # → extra_iterations = 2
    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))

    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
                            cost=0, node_error_count=0), threading.Lock())
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == [2]


@pytest.mark.asyncio
async def test_on_node_execution_skips_when_status_not_completed(gated_processor):
    from backend.data.execution import ExecutionStatus

    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.FAILED
    inner["llm_call_count"] = 5
    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))

    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
                            cost=0, node_error_count=0), threading.Lock())
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == []


@pytest.mark.asyncio
async def test_on_node_execution_skips_when_charge_flag_false(gated_processor):
    from backend.data.execution import ExecutionStatus

    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 5
    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=False))

    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
                            cost=0, node_error_count=0), threading.Lock())
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == []


@pytest.mark.asyncio
async def test_on_node_execution_skips_when_llm_call_count_le_1(gated_processor):
    from backend.data.execution import ExecutionStatus

    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 1  # exactly the base charge, no extras
    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))

    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
                            cost=0, node_error_count=0), threading.Lock())
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == []


@pytest.mark.asyncio
async def test_on_node_execution_skips_when_dry_run(gated_processor):
    from backend.data.execution import ExecutionStatus

    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 5
    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))

    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
                            cost=0, node_error_count=0), threading.Lock())
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=True),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == []


@pytest.mark.asyncio
async def test_on_node_execution_insufficient_balance_records_error_and_notifies(
    monkeypatch, gated_processor,
):
    """When extra-iteration charging fails with InsufficientBalanceError:

    - the run still reports COMPLETED (the work is already done)
    - execution_stats.error is set so monitoring picks it up
    - _handle_insufficient_funds_notif is called so the user is notified
    """
    from backend.data.execution import ExecutionStatus
    from backend.executor import manager
    from backend.util.exceptions import InsufficientBalanceError

    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 4
    fake_db.get_node = AsyncMock(return_value=_FakeNode(charge_per_llm_call=True))

    def raise_ibe(self, node_exec, extra_iterations):
        raise InsufficientBalanceError(
            user_id=node_exec.user_id,
            message="Insufficient balance",
            balance=0,
            amount=extra_iterations * 10,
        )

    monkeypatch.setattr(
        manager.ExecutionProcessor, "charge_extra_iterations", raise_ibe
    )

    stats_pair = (MagicMock(node_count=0, nodes_cputime=0, nodes_walltime=0,
                            cost=0, node_error_count=0), threading.Lock())
    result_stats = await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    # Error recorded on stats so downstream monitoring can surface it.
    assert isinstance(result_stats.error, InsufficientBalanceError)
    # User notification fired.
    assert len(calls["handle_insufficient_funds_notif"]) == 1
    assert calls["handle_insufficient_funds_notif"][0]["user_id"] == "u"


# ── Orchestrator _execute_single_tool_with_manager charging gates ──


async def _run_tool_exec_with_stats(
    *, dry_run: bool, tool_stats_error, charge_node_usage_mock=None,
):
    """Invoke _execute_single_tool_with_manager against fully mocked deps
    and return (charge_call_count, merge_stats_calls).

    Used to prove the dry_run and error guards around charge_node_usage
    behave as documented, and that InsufficientBalanceError propagates.
    """
    import threading
    from collections import defaultdict
    from unittest.mock import AsyncMock, MagicMock, patch

    from backend.blocks.orchestrator import ExecutionParams, OrchestratorBlock
    from backend.data.execution import ExecutionContext

    block = OrchestratorBlock()

    # Mocked async DB client used inside orchestrator.
    mock_db_client = AsyncMock()
    mock_target_node = MagicMock()
    mock_target_node.block_id = "test-block-id"
    mock_target_node.input_default = {}
    mock_db_client.get_node.return_value = mock_target_node
    mock_node_exec_result = MagicMock()
    mock_node_exec_result.node_exec_id = "test-tool-exec-id"
    mock_db_client.upsert_execution_input.return_value = (
        mock_node_exec_result,
        {"query": "t"},
    )
    mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
        "result": "ok"
    }

    # ExecutionProcessor mock: on_node_execution returns supplied error.
    mock_processor = AsyncMock()
    mock_processor.running_node_execution = defaultdict(MagicMock)
    mock_processor.execution_stats = MagicMock()
    mock_processor.execution_stats_lock = threading.Lock()
    mock_node_stats = MagicMock()
    mock_node_stats.error = tool_stats_error
    mock_processor.on_node_execution = AsyncMock(return_value=mock_node_stats)
    mock_processor.charge_node_usage = (
        charge_node_usage_mock or MagicMock(return_value=(10, 990))
    )

    # Build a tool_info shaped like _build_tool_info_from_args output.
    tool_call = MagicMock()
    tool_call.id = "call-1"
    tool_call.name = "search_keywords"
    tool_call.arguments = '{"query":"t"}'
    tool_def = {
        "type": "function",
        "function": {
            "name": "search_keywords",
            "_sink_node_id": "test-sink-node-id",
            "_field_mapping": {},
            "parameters": {
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
            },
        },
    }
    tool_info = OrchestratorBlock._build_tool_info_from_args(
        tool_call_id="call-1",
        tool_name="search_keywords",
        tool_args={"query": "t"},
        tool_def=tool_def,
    )

    exec_params = ExecutionParams(
        user_id="u",
        graph_id="g",
        node_id="n",
        graph_version=1,
        graph_exec_id="ge",
        node_exec_id="ne",
        execution_context=ExecutionContext(
            human_in_the_loop_safe_mode=False, dry_run=dry_run
        ),
    )

    with patch(
        "backend.blocks.orchestrator.get_database_manager_async_client",
        return_value=mock_db_client,
    ):
        try:
            await block._execute_single_tool_with_manager(
                tool_info, exec_params, mock_processor, responses_api=False
            )
            raised = None
        except Exception as e:
            raised = e

    return mock_processor.charge_node_usage, raised


@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_dry_run():
    """dry_run=True → charge_node_usage is NOT called."""
    charge_mock, raised = await _run_tool_exec_with_stats(
        dry_run=True, tool_stats_error=None
    )
    assert raised is None
    assert charge_mock.call_count == 0


@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_failed_tool():
    """tool_node_stats.error is an Exception → charge_node_usage NOT called."""
    charge_mock, raised = await _run_tool_exec_with_stats(
        dry_run=False, tool_stats_error=RuntimeError("tool blew up")
    )
    assert raised is None
    assert charge_mock.call_count == 0


@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_cancelled_tool():
    """Cancellation (BaseException subclass) → charge_node_usage NOT called.

    Guards the fix for sentry's BaseException concern: the old
    `isinstance(error, Exception)` check would have treated CancelledError
    as "no error" and billed the user for a terminated run.
    """
    import asyncio as _asyncio

    charge_mock, raised = await _run_tool_exec_with_stats(
        dry_run=False, tool_stats_error=_asyncio.CancelledError()
    )
    assert raised is None
    assert charge_mock.call_count == 0


@pytest.mark.asyncio
async def test_tool_execution_insufficient_balance_propagates():
    """InsufficientBalanceError from charge_node_usage must propagate out.

    If this leaked into a ToolCallResult the LLM loop would keep running
    with 'tool failed' errors and the user would get unpaid work.
    """
    from unittest.mock import MagicMock

    from backend.util.exceptions import InsufficientBalanceError

    raising_charge = MagicMock(
        side_effect=InsufficientBalanceError(
            user_id="u", message="nope", balance=0, amount=10
        )
    )
    _, raised = await _run_tool_exec_with_stats(
        dry_run=False,
        tool_stats_error=None,
        charge_node_usage_mock=raising_charge,
    )
    assert isinstance(raised, InsufficientBalanceError)
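The added tests construct the processor with `ExecutionProcessor.__new__` and patch its collaborators at class level instead of running `__init__`. The same pattern in miniature, with throwaway names (not the real class; monkeypatch.setattr would additionally revert the patch after the test):

```python
from unittest.mock import MagicMock

class Processor:
    def __init__(self):
        raise RuntimeError("heavy setup we want to skip in tests")

    def charge(self, n):
        return n * 10

# __new__ skips __init__ entirely; the instance still dispatches through the
# class, so patching the class attribute affects the already-built instance.
proc = Processor.__new__(Processor)
Processor.charge = MagicMock(return_value=0)
assert proc.charge(5) == 0
```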
@@ -684,10 +684,16 @@ class ExecutionProcessor:
        # is already covered by _charge_usage(); each additional LLM call
        # costs another base_cost. Skipped for dry runs and failed runs.
        #
-       # InsufficientBalanceError is logged at ERROR level (this is a
-       # billing leak — the work is already done, but the user can't pay)
-       # and re-surfaced via execution_stats.error so monitoring can pick
-       # it up. Other exceptions are warnings.
+       # InsufficientBalanceError here is a post-hoc billing leak — the
+       # work is already done but the user can no longer pay. We:
+       #   1. log at ERROR with structured fields so alerting can catch it
+       #   2. record the error on execution_stats.error for downstream
+       #      monitoring (stats are persisted into node_stats below)
+       #   3. fire _handle_insufficient_funds_notif so the user is
+       #      notified (mirrors the main queue path at ~line 1254)
+       # The run itself is kept COMPLETED (the block's outputs are
+       # already committed) — matching the documented "billing leak"
+       # contract rather than retroactively failing a successful run.
        if (
            status == ExecutionStatus.COMPLETED
            and node.block.charge_per_llm_call

@@ -711,9 +717,36 @@ class ExecutionProcessor:
            )
        except InsufficientBalanceError as e:
            log_metadata.error(
-               f"Billing leak: insufficient balance after {node.block.name} "
-               f"completed {extra_iterations} extra iterations: {e}"
+               "billing_leak: insufficient balance after "
+               f"{node.block.name} completed {extra_iterations} "
+               f"extra iterations",
+               extra={
+                   "billing_leak": True,
+                   "user_id": node_exec.user_id,
+                   "graph_id": node_exec.graph_id,
+                   "block_id": node_exec.block_id,
+                   "extra_iterations": extra_iterations,
+                   "error": str(e),
+               },
            )
+           # Surface on execution_stats so node_stats persistence
+           # below records the billing failure for monitoring.
+           execution_stats.error = e
+           # Notify the user they're out of credits. Runs through
+           # Redis dedup (per user+graph) so repeat runs don't spam.
+           try:
+               await asyncio.to_thread(
+                   self._handle_insufficient_funds_notif,
+                   get_db_client(),
+                   node_exec.user_id,
+                   node_exec.graph_id,
+                   e,
+               )
+           except Exception as notif_error:  # pragma: no cover
+               log_metadata.warning(
+                   f"Failed to send insufficient funds notification: "
+                   f"{notif_error}"
+               )
        except Exception as e:
            log_metadata.warning(
                f"Failed to charge extra iterations for {node.block.name}: {e}"

@@ -982,6 +1015,27 @@ class ExecutionProcessor:
            stats=exec_stats,
        )

+   def _resolve_block_cost(
+       self,
+       node_exec: NodeExecutionEntry,
+   ) -> tuple[Any, int, dict]:
+       """Look up the block and compute its base usage cost for an exec.
+
+       Shared by :meth:`_charge_usage` and :meth:`charge_extra_iterations`
+       so the (get_block, block_usage_cost) lookup lives in exactly one
+       place. Returns ``(block, cost, matching_filter)``. ``block`` is
+       ``None`` if the block id can't be resolved — callers should treat
+       that as "nothing to charge".
+       """
+       block = get_block(node_exec.block_id)
+       if not block:
+           logger.error(f"Block {node_exec.block_id} not found.")
+           return None, 0, {}
+       cost, matching_filter = block_usage_cost(
+           block=block, input_data=node_exec.inputs
+       )
+       return block, cost, matching_filter
+
    def _charge_usage(
        self,
        node_exec: NodeExecutionEntry,

@@ -990,14 +1044,10 @@ class ExecutionProcessor:
        total_cost = 0
        remaining_balance = 0
        db_client = get_db_client()
-       block = get_block(node_exec.block_id)
+       block, cost, matching_filter = self._resolve_block_cost(node_exec)
        if not block:
-           logger.error(f"Block {node_exec.block_id} not found.")
            return total_cost, 0

-       cost, matching_filter = block_usage_cost(
-           block=block, input_data=node_exec.inputs
-       )
        if cost > 0:
            remaining_balance = db_client.spend_credits(
                user_id=node_exec.user_id,

@@ -1081,13 +1131,8 @@ class ExecutionProcessor:
        # Cap to protect against a corrupted llm_call_count.
        capped_iterations = min(extra_iterations, self._MAX_EXTRA_ITERATIONS)
        db_client = get_db_client()
-       block = get_block(node_exec.block_id)
-       if not block:
-           return 0, 0
-       cost, matching_filter = block_usage_cost(
-           block=block, input_data=node_exec.inputs
-       )
-       if cost <= 0:
+       block, cost, matching_filter = self._resolve_block_cost(node_exec)
+       if not block or cost <= 0:
            return 0, 0
        total_extra_cost = cost * capped_iterations
        remaining_balance = db_client.spend_credits(