mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
Merge branch 'fix/orchestrator-per-iteration-cost' of https://github.com/Significant-Gravitas/AutoGPT into preview/all-prs
This commit is contained in:
@@ -420,6 +420,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig):
|
||||
class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
|
||||
_optimized_description: ClassVar[str | None] = None
|
||||
|
||||
def extra_credit_charges(self, execution_stats: "NodeExecutionStats") -> int:
    """Report how many extra base-cost charges this run should incur.

    Hook for blocks whose single run covers several billable units of work.
    The executor consults it once the block has finished with COMPLETED
    status; the returned count is billed on top of the one charge that
    ``_charge_usage`` already collected when execution began.

    Args:
        execution_stats: Stats gathered for the finished run. This default
            implementation does not read them.

    Returns:
        Always ``0`` here, meaning nothing beyond the up-front charge.
        Blocks that make several LLM calls in one run (e.g.
        OrchestratorBlock) override this so they are billed per call.
    """
    return 0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str = "",
|
||||
|
||||
@@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
|
||||
from backend.util.security import SENSITIVE_FIELD_NAMES
|
||||
from backend.util.tool_call_loop import (
|
||||
@@ -364,10 +365,25 @@ def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None:
|
||||
|
||||
|
||||
class OrchestratorBlock(Block):
|
||||
"""A block that uses a language model to orchestrate tool calls.
|
||||
|
||||
Supports both single-shot and iterative agent mode execution.
|
||||
|
||||
**InsufficientBalanceError propagation contract**: ``InsufficientBalanceError``
|
||||
(IBE) must always re-raise through every ``except`` block in this class.
|
||||
Swallowing IBE would let the agent loop continue with unpaid work. Every
|
||||
exception handler that catches ``Exception`` includes an explicit IBE
|
||||
re-raise carve-out for this reason.
|
||||
"""
|
||||
A block that uses a language model to orchestrate tool calls, supporting both
|
||||
single-shot and iterative agent mode execution.
|
||||
"""
|
||||
|
||||
def extra_credit_charges(self, execution_stats: NodeExecutionStats) -> int:
    """Return how many LLM calls from this run are still unbilled.

    Agent mode issues one LLM call per iteration. ``_charge_usage()`` has
    already paid for the first call, so every call after it is reported
    here for the executor to bill once the block has completed.

    Args:
        execution_stats: Stats for the finished run; only
            ``llm_call_count`` is read.

    Returns:
        ``llm_call_count - 1``, floored at ``0`` so a run with zero or one
        call never yields a negative charge.
    """
    unbilled_calls = execution_stats.llm_call_count - 1
    return unbilled_calls if unbilled_calls > 0 else 0
|
||||
|
||||
# MCP server name used by the Claude Code SDK execution mode. Keep in sync
|
||||
# with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode.
|
||||
@@ -1077,7 +1093,10 @@ class OrchestratorBlock(Block):
|
||||
input_data=input_value,
|
||||
)
|
||||
|
||||
assert node_exec_result is not None, "node_exec_result should not be None"
|
||||
if node_exec_result is None:
|
||||
raise RuntimeError(
|
||||
f"upsert_execution_input returned None for node {sink_node_id}"
|
||||
)
|
||||
|
||||
# Create NodeExecutionEntry for execution manager
|
||||
node_exec_entry = NodeExecutionEntry(
|
||||
@@ -1112,15 +1131,79 @@ class OrchestratorBlock(Block):
|
||||
task=node_exec_future,
|
||||
)
|
||||
|
||||
# Execute the node directly since we're in the Orchestrator context
|
||||
node_exec_future.set_result(
|
||||
await execution_processor.on_node_execution(
|
||||
# Execute the node directly since we're in the Orchestrator context.
|
||||
# Wrap in try/except so the future is always resolved, even on
|
||||
# error — an unresolved Future would block anything awaiting it.
|
||||
#
|
||||
# on_node_execution is decorated with @async_error_logged(swallow=True),
|
||||
# which catches BaseException and returns None rather than raising.
|
||||
# Treat a None return as a failure: set_exception so the future
|
||||
# carries an error state rather than a None result, and return an
|
||||
# error response so the LLM knows the tool failed.
|
||||
try:
|
||||
tool_node_stats = await execution_processor.on_node_execution(
|
||||
node_exec=node_exec_entry,
|
||||
node_exec_progress=node_exec_progress,
|
||||
nodes_input_masks=None,
|
||||
graph_stats_pair=graph_stats_pair,
|
||||
)
|
||||
)
|
||||
if tool_node_stats is None:
|
||||
nil_err = RuntimeError(
|
||||
f"on_node_execution returned None for node {sink_node_id} "
|
||||
"(error was swallowed by @async_error_logged)"
|
||||
)
|
||||
node_exec_future.set_exception(nil_err)
|
||||
resp = _create_tool_response(
|
||||
tool_call.id,
|
||||
"Tool execution returned no result",
|
||||
responses_api=responses_api,
|
||||
)
|
||||
resp["_is_error"] = True
|
||||
return resp
|
||||
node_exec_future.set_result(tool_node_stats)
|
||||
except Exception as exec_err:
|
||||
node_exec_future.set_exception(exec_err)
|
||||
raise
|
||||
|
||||
# Charge user credits AFTER successful tool execution. Tools
|
||||
# spawned by the orchestrator bypass the main execution queue
|
||||
# (where _charge_usage is called), so we must charge here to
|
||||
# avoid free tool execution. Charging post-completion (vs.
|
||||
# pre-execution) avoids billing users for failed tool calls.
|
||||
# Skipped for dry runs.
|
||||
#
|
||||
# `error is None` intentionally excludes both Exception and
|
||||
# BaseException subclasses (e.g. CancelledError) so cancelled
|
||||
# or terminated tool runs are not billed.
|
||||
#
|
||||
# Billing errors (including non-balance exceptions) are kept
|
||||
# in a separate try/except so they are never silently swallowed
|
||||
# by the generic tool-error handler below.
|
||||
if (
|
||||
not execution_params.execution_context.dry_run
|
||||
and tool_node_stats.error is None
|
||||
):
|
||||
try:
|
||||
tool_cost, _ = await execution_processor.charge_node_usage(
|
||||
node_exec_entry,
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see OrchestratorBlock class docstring.
|
||||
raise
|
||||
except Exception:
|
||||
# Non-billing charge failures (DB outage, network, etc.)
|
||||
# must NOT propagate to the outer except handler because
|
||||
# the tool itself succeeded. Re-raising would mark the
|
||||
# tool as failed (_is_error=True), causing the LLM to
|
||||
# retry side-effectful operations. Log and continue.
|
||||
logger.exception(
|
||||
"Unexpected error charging for tool node %s; "
|
||||
"tool execution was successful",
|
||||
sink_node_id,
|
||||
)
|
||||
tool_cost = 0
|
||||
if tool_cost > 0:
|
||||
self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))
|
||||
|
||||
# Get outputs from database after execution completes using database manager client
|
||||
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
|
||||
@@ -1133,18 +1216,26 @@ class OrchestratorBlock(Block):
|
||||
if node_outputs
|
||||
else "Tool executed successfully"
|
||||
)
|
||||
return _create_tool_response(
|
||||
resp = _create_tool_response(
|
||||
tool_call.id, tool_response_content, responses_api=responses_api
|
||||
)
|
||||
resp["_is_error"] = False
|
||||
return resp
|
||||
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning("Tool execution with manager failed: %s", e)
|
||||
# Return error response
|
||||
return _create_tool_response(
|
||||
logger.warning("Tool execution with manager failed: %s", e, exc_info=True)
|
||||
# Return a generic error to the LLM — internal exception messages
|
||||
# may contain server paths, DB details, or infrastructure info.
|
||||
resp = _create_tool_response(
|
||||
tool_call.id,
|
||||
f"Tool execution failed: {e}",
|
||||
"Tool execution failed due to an internal error",
|
||||
responses_api=responses_api,
|
||||
)
|
||||
resp["_is_error"] = True
|
||||
return resp
|
||||
|
||||
async def _agent_mode_llm_caller(
|
||||
self,
|
||||
@@ -1244,13 +1335,16 @@ class OrchestratorBlock(Block):
|
||||
content = str(raw_content)
|
||||
else:
|
||||
content = "Tool executed successfully"
|
||||
tool_failed = content.startswith("Tool execution failed:")
|
||||
tool_failed = result.get("_is_error", True)
|
||||
return ToolCallResult(
|
||||
tool_call_id=tool_call.id,
|
||||
tool_name=tool_call.name,
|
||||
content=content,
|
||||
is_error=tool_failed,
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Tool execution failed: %s", e)
|
||||
return ToolCallResult(
|
||||
@@ -1370,9 +1464,13 @@ class OrchestratorBlock(Block):
|
||||
"arguments": tc.arguments,
|
||||
},
|
||||
)
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring.
|
||||
raise
|
||||
except Exception as e:
|
||||
# Catch all errors (validation, network, API) so that the block
|
||||
# surfaces them as user-visible output instead of crashing.
|
||||
# Catch all OTHER errors (validation, network, API) so that
|
||||
# the block surfaces them as user-visible output instead of
|
||||
# crashing.
|
||||
yield "error", str(e)
|
||||
return
|
||||
|
||||
@@ -1450,11 +1548,14 @@ class OrchestratorBlock(Block):
|
||||
text = content
|
||||
else:
|
||||
text = json.dumps(content)
|
||||
tool_failed = text.startswith("Tool execution failed:")
|
||||
tool_failed = result.get("_is_error", True)
|
||||
return {
|
||||
"content": [{"type": "text", "text": text}],
|
||||
"isError": tool_failed,
|
||||
}
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("SDK tool execution failed: %s", e)
|
||||
return {
|
||||
@@ -1733,11 +1834,15 @@ class OrchestratorBlock(Block):
|
||||
await pending_task
|
||||
except (asyncio.CancelledError, StopAsyncIteration):
|
||||
pass
|
||||
except InsufficientBalanceError:
|
||||
# IBE must propagate — see class docstring. The `finally`
|
||||
# block below still runs and records partial token usage.
|
||||
raise
|
||||
except Exception as e:
|
||||
# Surface SDK errors as user-visible output instead of crashing,
|
||||
# consistent with _execute_tools_agent_mode error handling.
|
||||
# Don't return yet — fall through to merge_stats below so
|
||||
# partial token usage is always recorded.
|
||||
# Surface OTHER SDK errors as user-visible output instead
|
||||
# of crashing, consistent with _execute_tools_agent_mode
|
||||
# error handling. Don't return yet — fall through to
|
||||
# merge_stats below so partial token usage is always recorded.
|
||||
sdk_error = e
|
||||
finally:
|
||||
# Always record usage stats, even on error. The SDK may have
|
||||
|
||||
@@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode():
|
||||
mock_execution_processor.on_node_execution = AsyncMock(
|
||||
return_value=mock_node_stats
|
||||
)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Returns (cost, remaining_balance). Must be AsyncMock because it is
|
||||
# an async method and is directly awaited in _execute_single_tool_with_manager.
|
||||
# Use a non-zero cost so the merge_stats branch is exercised.
|
||||
mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990))
|
||||
|
||||
# Mock the get_execution_outputs_by_node_exec_id method
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
|
||||
@@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode():
|
||||
# Verify tool was executed via execution processor
|
||||
assert mock_execution_processor.on_node_execution.call_count == 1
|
||||
|
||||
# Verify charge_node_usage was actually called for the successful
|
||||
# tool execution — this guards against regressions where the
|
||||
# post-execution tool charging is accidentally removed.
|
||||
assert mock_execution_processor.charge_node_usage.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_orchestrator_traditional_mode_default():
|
||||
|
||||
@@ -641,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
mock_execution_processor.on_node_execution.return_value = (
|
||||
mock_node_stats
|
||||
)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Must be AsyncMock because it is async and is awaited in
|
||||
# _execute_single_tool_with_manager — a plain MagicMock would
|
||||
# return a non-awaitable tuple and TypeError out, then be
|
||||
# silently swallowed by the orchestrator's catch-all.
|
||||
mock_execution_processor.charge_node_usage = AsyncMock(
|
||||
return_value=(0, 0)
|
||||
)
|
||||
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
|
||||
@@ -0,0 +1,973 @@
|
||||
"""Tests for OrchestratorBlock per-iteration cost charging.
|
||||
|
||||
The OrchestratorBlock in agent mode makes multiple LLM calls in a single
|
||||
node execution. The executor uses ``Block.extra_credit_charges`` to detect
|
||||
this and charge ``base_cost * (llm_call_count - 1)`` extra credits after
|
||||
the block completes.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks._base import Block
|
||||
from backend.blocks.orchestrator import ExecutionParams, OrchestratorBlock
|
||||
from backend.data.execution import ExecutionContext, ExecutionStatus
|
||||
from backend.data.model import NodeExecutionStats
|
||||
from backend.executor import manager
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
# ── extra_credit_charges hook ────────────────────────────────────────
|
||||
|
||||
|
||||
class _NoOpBlock(Block):
    """Bare-bones concrete ``Block`` that keeps the inherited ``extra_credit_charges``."""

    def __init__(self):
        # Only identifying metadata is supplied; every other constructor
        # argument is left at the base-class default.
        super().__init__(
            description="No-op test block",
            id="noop-block",
        )

    def run(self, input_data, **kwargs):  # type: ignore[override]
        # Emits a single empty output; the tests visible here never call it.
        yield ("out", {})
|
||||
|
||||
|
||||
class TestExtraCreditCharges:
    """Check the ``extra_credit_charges`` hook on OrchestratorBlock and on a plain Block."""

    def test_orchestrator_returns_nonzero_for_multiple_calls(self):
        # Three LLM calls: the first is prepaid, two remain billable.
        extra = OrchestratorBlock().extra_credit_charges(
            NodeExecutionStats(llm_call_count=3)
        )
        assert extra == 2

    def test_orchestrator_returns_zero_for_single_call(self):
        # A single call is fully covered by the up-front charge.
        extra = OrchestratorBlock().extra_credit_charges(
            NodeExecutionStats(llm_call_count=1)
        )
        assert extra == 0

    def test_orchestrator_returns_zero_for_zero_calls(self):
        # Zero calls must floor at 0 rather than produce -1.
        extra = OrchestratorBlock().extra_credit_charges(
            NodeExecutionStats(llm_call_count=0)
        )
        assert extra == 0

    def test_default_block_returns_zero(self):
        """The base-class hook yields 0 regardless of ``llm_call_count``."""
        extra = _NoOpBlock().extra_credit_charges(
            NodeExecutionStats(llm_call_count=10)
        )
        assert extra == 0
|
||||
|
||||
|
||||
# ── charge_extra_iterations math ───────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture()
def fake_node_exec():
    """Provide a ``MagicMock`` node execution filled with short placeholder ids."""
    placeholder_fields = {
        "user_id": "u",
        "graph_exec_id": "g",
        "graph_id": "g",
        "node_exec_id": "ne",
        "node_id": "n",
        "block_id": "b",
        "inputs": {},
    }
    entry = MagicMock()
    # Plain attribute assignment, one field at a time, so each attribute
    # holds the real value instead of an auto-created child mock.
    for field_name, value in placeholder_fields.items():
        setattr(entry, field_name, value)
    return entry
|
||||
|
||||
|
||||
@pytest.fixture()
def patched_processor(monkeypatch):
    """Build an ExecutionProcessor whose billing collaborators are stubbed.

    ``manager.get_db_client``, ``manager.get_block`` and
    ``manager.block_usage_cost`` are replaced so that the base cost is
    always 10 and every ``spend_credits`` call is recorded.

    Returns:
        ``(processor, charges)`` where ``charges`` lists each amount passed
        to ``spend_credits``, in call order.

    Note: the processor is created with ``ExecutionProcessor.__new__()``,
    which skips ``__init__``. If ``__init__`` ever sets state that the
    methods under test need, this fixture must be updated.
    """
    charges: list[int] = []

    class _RecordingDb:
        def spend_credits(self, *, user_id, cost, metadata):
            charges.append(cost)
            # Value reported back as the remaining balance.
            return 1000

    block_stub = MagicMock()
    block_stub.name = "FakeBlock"

    def _get_db_client():
        return _RecordingDb()

    def _get_block(block_id):
        return block_stub

    def _block_usage_cost(block, input_data, **_kw):
        return 10, {"model": "claude-sonnet-4-6"}

    monkeypatch.setattr(manager, "get_db_client", _get_db_client)
    monkeypatch.setattr(manager, "get_block", _get_block)
    monkeypatch.setattr(manager, "block_usage_cost", _block_usage_cost)

    processor = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
    return processor, charges
|
||||
|
||||
|
||||
class TestChargeExtraIterations:
    """Exercise ``ExecutionProcessor.charge_extra_iterations`` with billing stubbed.

    Tests that need the default setup (base cost 10, balance 1000) use the
    ``patched_processor`` fixture. Tests that vary the base cost, the block
    lookup or the DB behaviour go through ``_stub_billing`` so the stub
    wiring lives in one place.
    """

    @staticmethod
    def _stub_billing(
        monkeypatch,
        *,
        base_cost=10,
        block_found=True,
        remaining_balance=1000,
        insufficient_balance=False,
    ):
        """Patch ``manager``'s billing helpers and build a bare processor.

        Args:
            monkeypatch: pytest's ``monkeypatch`` fixture.
            base_cost: Cost returned by the stubbed ``block_usage_cost``.
            block_found: When False, the stubbed ``get_block`` returns None.
            remaining_balance: Value returned by ``spend_credits``.
            insufficient_balance: When True, ``spend_credits`` raises
                ``InsufficientBalanceError`` instead of recording the spend.

        Returns:
            ``(processor, spent)`` where ``spent`` lists every amount
            passed to ``spend_credits``.
        """
        spent: list[int] = []

        class FakeDb:
            def spend_credits(self, *, user_id, cost, metadata):
                if insufficient_balance:
                    raise InsufficientBalanceError(
                        user_id=user_id,
                        message="Insufficient balance",
                        balance=0,
                        amount=cost,
                    )
                spent.append(cost)
                return remaining_balance

        fake_block = None
        if block_found:
            fake_block = MagicMock()
            fake_block.name = "FakeBlock"

        monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
        monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
        monkeypatch.setattr(
            manager,
            "block_usage_cost",
            lambda block, input_data, **_kw: (base_cost, {}),
        )

        # __new__ bypasses __init__, matching the patched_processor fixture.
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        return proc, spent

    @pytest.mark.asyncio
    async def test_zero_extra_iterations_charges_nothing(
        self, patched_processor, fake_node_exec
    ):
        proc, spent = patched_processor
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=0
        )
        assert cost == 0
        assert balance == 0
        assert spent == []

    @pytest.mark.asyncio
    async def test_extra_iterations_multiplies_base_cost(
        self, patched_processor, fake_node_exec
    ):
        proc, spent = patched_processor
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=4
        )
        assert cost == 40  # 4 × 10
        assert balance == 1000
        assert spent == [40]

    @pytest.mark.asyncio
    async def test_negative_extra_iterations_charges_nothing(
        self, patched_processor, fake_node_exec
    ):
        proc, spent = patched_processor
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=-1
        )
        assert cost == 0
        assert balance == 0
        assert spent == []

    @pytest.mark.asyncio
    async def test_capped_at_max(self, monkeypatch, fake_node_exec):
        """Runaway llm_call_count is capped at _MAX_EXTRA_ITERATIONS."""
        proc, spent = self._stub_billing(monkeypatch)
        cap = manager.ExecutionProcessor._MAX_EXTRA_ITERATIONS
        cost, _ = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=cap * 100
        )
        # Charged at most cap × 10
        assert cost == cap * 10
        assert spent == [cap * 10]

    @pytest.mark.asyncio
    async def test_zero_base_cost_skips_charge(self, monkeypatch, fake_node_exec):
        proc, spent = self._stub_billing(
            monkeypatch, base_cost=0, remaining_balance=0
        )
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=4
        )
        assert cost == 0
        assert balance == 0
        assert spent == []

    @pytest.mark.asyncio
    async def test_block_not_found_skips_charge(self, monkeypatch, fake_node_exec):
        proc, spent = self._stub_billing(
            monkeypatch, block_found=False, remaining_balance=0
        )
        cost, balance = await proc.charge_extra_iterations(
            fake_node_exec, extra_iterations=3
        )
        assert cost == 0
        assert balance == 0
        assert spent == []

    @pytest.mark.asyncio
    async def test_propagates_insufficient_balance_error(
        self, monkeypatch, fake_node_exec
    ):
        """Out-of-credits errors must propagate, not be silently swallowed."""
        proc, _ = self._stub_billing(monkeypatch, insufficient_balance=True)
        with pytest.raises(InsufficientBalanceError):
            await proc.charge_extra_iterations(fake_node_exec, extra_iterations=4)
|
||||
|
||||
|
||||
# ── charge_node_usage ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestChargeNodeUsage:
    """charge_node_usage delegates to _charge_usage with execution_count=0.

    All tests share ``_stub_charging``, which replaces ``_charge_usage``,
    ``_handle_low_balance`` and ``manager.get_db_client`` with recorders.
    """

    @staticmethod
    def _stub_charging(monkeypatch, charge_result):
        """Install recording stubs and build a bare processor.

        Args:
            monkeypatch: pytest's ``monkeypatch`` fixture.
            charge_result: ``(cost, remaining_balance)`` tuple returned by
                the stubbed ``_charge_usage``.

        Returns:
            ``(processor, charge_calls, low_balance_calls)``. Each list
            holds one dict of received arguments per call.
        """
        charge_calls: list[dict] = []
        low_balance_calls: list[dict] = []

        def fake_charge_usage(self, node_exec, execution_count):
            charge_calls.append(
                {"node_exec": node_exec, "execution_count": execution_count}
            )
            return charge_result

        def fake_handle_low_balance(
            self, db_client, user_id, current_balance, transaction_cost
        ):
            low_balance_calls.append(
                {
                    "user_id": user_id,
                    "current_balance": current_balance,
                    "transaction_cost": transaction_cost,
                }
            )

        monkeypatch.setattr(
            manager.ExecutionProcessor, "_charge_usage", fake_charge_usage
        )
        monkeypatch.setattr(
            manager.ExecutionProcessor, "_handle_low_balance", fake_handle_low_balance
        )
        monkeypatch.setattr(manager, "get_db_client", lambda: MagicMock())

        # __new__ bypasses __init__; only the patched methods are exercised.
        proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        return proc, charge_calls, low_balance_calls

    @pytest.mark.asyncio
    async def test_delegates_with_zero_execution_count(
        self, monkeypatch, fake_node_exec
    ):
        """Nested tool charges should NOT inflate the per-execution counter."""
        proc, charge_calls, _ = self._stub_charging(monkeypatch, (5, 100))
        cost, balance = await proc.charge_node_usage(fake_node_exec)
        assert cost == 5
        assert balance == 100
        # Indexing fails loudly if _charge_usage was never reached.
        assert charge_calls[-1]["execution_count"] == 0

    @pytest.mark.asyncio
    async def test_calls_handle_low_balance_when_cost_nonzero(
        self, monkeypatch, fake_node_exec
    ):
        """charge_node_usage should call _handle_low_balance when total_cost > 0."""
        proc, _, low_balance_calls = self._stub_charging(monkeypatch, (10, 50))
        cost, balance = await proc.charge_node_usage(fake_node_exec)
        assert cost == 10
        assert balance == 50
        assert len(low_balance_calls) == 1
        assert low_balance_calls[0]["user_id"] == "u"
        assert low_balance_calls[0]["current_balance"] == 50
        assert low_balance_calls[0]["transaction_cost"] == 10

    @pytest.mark.asyncio
    async def test_skips_handle_low_balance_when_cost_zero(
        self, monkeypatch, fake_node_exec
    ):
        """charge_node_usage should NOT call _handle_low_balance when cost is 0."""
        proc, _, low_balance_calls = self._stub_charging(monkeypatch, (0, 200))
        # The returned balance is not asserted on, so it is discarded.
        cost, _ = await proc.charge_node_usage(fake_node_exec)
        assert cost == 0
        assert low_balance_calls == []
|
||||
|
||||
|
||||
# ── on_node_execution charging gate ────────────────────────────────
|
||||
|
||||
|
||||
class _FakeNode:
|
||||
"""Minimal stand-in for a ``Node`` object with a block attribute."""
|
||||
|
||||
def __init__(self, extra_charges: int = 0, block_name: str = "FakeBlock"):
|
||||
self.block = MagicMock()
|
||||
self.block.name = block_name
|
||||
self.block.extra_credit_charges = MagicMock(return_value=extra_charges)
|
||||
|
||||
|
||||
class _FakeExecContext:
|
||||
def __init__(self, dry_run: bool = False):
|
||||
self.dry_run = dry_run
|
||||
|
||||
|
||||
def _make_node_exec(dry_run: bool = False) -> MagicMock:
    """Create a mock shaped like a NodeExecutionEntry for on_node_execution tests.

    Args:
        dry_run: Value exposed as ``execution_context.dry_run``.

    Returns:
        A ``MagicMock`` with placeholder ids, empty ``inputs`` and a
        ``_FakeExecContext`` execution context.
    """
    entry = MagicMock()
    placeholder_fields = (
        ("user_id", "u"),
        ("graph_id", "g"),
        ("graph_exec_id", "ge"),
        ("node_id", "n"),
        ("node_exec_id", "ne"),
        ("block_id", "b"),
        ("inputs", {}),
    )
    for field_name, value in placeholder_fields:
        setattr(entry, field_name, value)
    entry.execution_context = _FakeExecContext(dry_run=dry_run)
    return entry
|
||||
|
||||
|
||||
@pytest.fixture()
def gated_processor(monkeypatch):
    """ExecutionProcessor with on_node_execution's downstream calls stubbed.

    Lets tests flip the gate conditions (status, extra_credit_charges result,
    llm_call_count, dry_run) and observe whether charge_extra_iterations
    was called.

    Returns:
        A 5-tuple ``(proc, calls, inner_result, fake_db, NodeExecutionStats)``:

        - ``proc``: processor built with ``__new__`` (``__init__`` skipped).
        - ``calls``: dict of lists recording the arguments received by the
          stubbed ``charge_extra_iterations``, ``_handle_low_balance`` and
          ``_handle_insufficient_funds_notif``.
        - ``inner_result``: mutable dict; set ``"status"`` and
          ``"llm_call_count"`` to control the stubbed ``_on_node_execution``.
        - ``fake_db``: the mock DB client; reassign ``get_node`` to change
          the node (and so the block's ``extra_credit_charges`` result).
        - ``NodeExecutionStats``: the stats class itself, passed through
          for convenience.
    """

    calls: dict[str, list] = {
        "charge_extra_iterations": [],
        "handle_low_balance": [],
        "handle_insufficient_funds_notif": [],
    }

    # Stub node lookup + DB client so the wrapper doesn't touch real infra.
    fake_db = MagicMock()
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2))
    monkeypatch.setattr(manager, "get_db_async_client", lambda: fake_db)
    monkeypatch.setattr(manager, "get_db_client", lambda: fake_db)

    # get_block is called by LogMetadata construction in on_node_execution.
    # ``.name`` is assigned after construction because ``MagicMock(name=...)``
    # only labels the mock for its repr; ``.name`` would stay a child mock
    # rather than the string "FakeBlock".
    fake_block = MagicMock()
    fake_block.name = "FakeBlock"
    monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)

    # Persistence + cost logging are not under test here.
    monkeypatch.setattr(
        manager,
        "async_update_node_execution_status",
        AsyncMock(return_value=None),
    )
    monkeypatch.setattr(
        manager,
        "async_update_graph_execution_state",
        AsyncMock(return_value=None),
    )
    monkeypatch.setattr(
        manager,
        "log_system_credential_cost",
        AsyncMock(return_value=None),
    )

    proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)

    # Control the status returned by the inner execution call.
    inner_result = {"status": ExecutionStatus.COMPLETED, "llm_call_count": 3}

    async def fake_inner(
        self,
        *,
        node,
        node_exec,
        node_exec_progress,
        stats,
        db_client,
        log_metadata,
        nodes_input_masks=None,
        nodes_to_skip=None,
    ):
        # Values are read at call time, so tests can mutate inner_result
        # after the fixture has been built.
        stats.llm_call_count = inner_result["llm_call_count"]
        return MagicMock(wall_time=0.1, cpu_time=0.1), inner_result["status"]

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_on_node_execution",
        fake_inner,
    )

    async def fake_charge_extra(self, node_exec, extra_iterations):
        calls["charge_extra_iterations"].append(extra_iterations)
        # (cost, remaining balance): 10 credits per iteration, 500 left.
        return (extra_iterations * 10, 500)

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "charge_extra_iterations",
        fake_charge_extra,
    )

    def fake_low_balance(self, db_client, user_id, current_balance, transaction_cost):
        calls["handle_low_balance"].append(
            {
                "user_id": user_id,
                "current_balance": current_balance,
                "transaction_cost": transaction_cost,
            }
        )

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_handle_low_balance",
        fake_low_balance,
    )

    def fake_notif(self, db_client, user_id, graph_id, e):
        calls["handle_insufficient_funds_notif"].append(
            {"user_id": user_id, "graph_id": graph_id, "error": e}
        )

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_handle_insufficient_funds_notif",
        fake_notif,
    )

    return proc, calls, inner_result, fake_db, NodeExecutionStats
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_node_execution_charges_extra_iterations_when_gate_passes(
    gated_processor,
):
    """COMPLETED + extra_credit_charges > 0 + not dry_run → charged."""

    proc, calls, inner, fake_db, _ = gated_processor
    inner["status"] = ExecutionStatus.COMPLETED
    inner["llm_call_count"] = 3  # → extra_charges = 2
    fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2))

    stats_pair = (
        MagicMock(
            node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0
        ),
        threading.Lock(),
    )
    await proc.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=stats_pair,
    )
    assert calls["charge_extra_iterations"] == [2]
    # _handle_low_balance must be called with the remaining balance returned by
    # charge_extra_iterations (500) so users are alerted when balance drops low.
    assert len(calls["handle_low_balance"]) == 1
    # Pin the balance itself, not just the call count, so passing the wrong
    # value (e.g. the cost) is caught.
    assert calls["handle_low_balance"][0]["current_balance"] == 500
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_node_execution_skips_when_status_not_completed(gated_processor):
    """A FAILED inner run must not trigger extra-iteration charging."""
    processor, recorded, inner_cfg, db, _ = gated_processor

    # Every other gate condition would allow the charge; only the status blocks it.
    inner_cfg.update(status=ExecutionStatus.FAILED, llm_call_count=5)
    db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4))

    graph_stats = MagicMock(
        node_count=0,
        nodes_cputime=0,
        nodes_walltime=0,
        cost=0,
        node_error_count=0,
    )
    await processor.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, threading.Lock()),
    )

    assert recorded["charge_extra_iterations"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_node_execution_skips_when_extra_charges_zero(gated_processor):
    """COMPLETED but zero extra charges → nothing additional is billed."""
    processor, recorded, inner_state, db, _ = gated_processor
    # Block returns 0 extra charges (base class default)
    db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0))
    inner_state["llm_call_count"] = 5
    inner_state["status"] = ExecutionStatus.COMPLETED

    graph_stats = MagicMock(
        node_count=0,
        nodes_cputime=0,
        nodes_walltime=0,
        cost=0,
        node_error_count=0,
    )
    await processor.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, threading.Lock()),
    )

    assert recorded["charge_extra_iterations"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_node_execution_skips_when_dry_run(gated_processor):
    """COMPLETED with pending extra charges, but dry_run=True → not billed."""
    processor, recorded, inner_state, db, _ = gated_processor
    db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4))
    inner_state["llm_call_count"] = 5
    inner_state["status"] = ExecutionStatus.COMPLETED

    graph_stats = MagicMock(
        node_count=0,
        nodes_cputime=0,
        nodes_walltime=0,
        cost=0,
        node_error_count=0,
    )
    await processor.on_node_execution(
        node_exec=_make_node_exec(dry_run=True),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, threading.Lock()),
    )

    assert recorded["charge_extra_iterations"] == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_node_execution_insufficient_balance_records_error_and_notifies(
    monkeypatch,
    gated_processor,
):
    """Post-hoc InsufficientBalanceError keeps stats clean but notifies.

    ``charge_extra_iterations`` is replaced with a coroutine that raises
    ``InsufficientBalanceError``. Expected outcome:

    - ``execution_stats.error`` is NOT set: the node ran to completion and
      only the post-hoc charge failed. Setting it would flip
      ``node_error_count`` and leak balance amounts into persisted node
      stats.
    - ``_handle_insufficient_funds_notif`` is recorded exactly once, for
      user ``"u"``, so the user is notified.
    """

    async def _always_broke(self, node_exec, extra_iterations):
        raise InsufficientBalanceError(
            user_id=node_exec.user_id,
            message="Insufficient balance",
            balance=0,
            amount=extra_iterations * 10,
        )

    processor, recorded, inner_state, db, _ = gated_processor
    db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3))
    inner_state["llm_call_count"] = 4
    inner_state["status"] = ExecutionStatus.COMPLETED
    monkeypatch.setattr(
        manager.ExecutionProcessor, "charge_extra_iterations", _always_broke
    )

    graph_stats = MagicMock(
        node_count=0,
        nodes_cputime=0,
        nodes_walltime=0,
        cost=0,
        node_error_count=0,
    )
    result_stats = await processor.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, threading.Lock()),
    )

    # The charge failure must not be surfaced as a node error.
    assert result_stats.error is None
    # User notification fired.
    notifications = recorded["handle_insufficient_funds_notif"]
    assert len(notifications) == 1
    assert notifications[0]["user_id"] == "u"
|
||||
|
||||
|
||||
# ── Orchestrator _execute_single_tool_with_manager charging gates ──
|
||||
|
||||
|
||||
async def _run_tool_exec_with_stats(
    *,
    dry_run: bool,
    tool_stats_error,
    charge_node_usage_mock=None,
):
    """Run ``_execute_single_tool_with_manager`` against fully mocked deps.

    Used to prove the dry_run and error guards around ``charge_node_usage``
    behave as documented, and that ``InsufficientBalanceError`` propagates.

    Args:
        dry_run: Value placed on the ``ExecutionContext`` of the exec params.
        tool_stats_error: Value exposed as ``.error`` on the stats object
            returned by the mocked ``on_node_execution``.
        charge_node_usage_mock: Optional replacement for the processor's
            ``charge_node_usage``. When omitted, an ``AsyncMock`` returning
            ``(10, 990)`` is installed.

    Returns:
        ``(charge_node_usage, raised)``: the mock installed as
        ``charge_node_usage`` (so callers can inspect its call count) and
        the exception raised by the call, or ``None`` if it returned
        normally.
    """
    from collections import defaultdict
    from unittest.mock import AsyncMock, MagicMock, patch

    block = OrchestratorBlock()

    # Mocked async DB client used inside orchestrator.
    mock_db_client = AsyncMock()
    mock_target_node = MagicMock()
    mock_target_node.block_id = "test-block-id"
    mock_target_node.input_default = {}
    mock_db_client.get_node.return_value = mock_target_node
    mock_node_exec_result = MagicMock()
    mock_node_exec_result.node_exec_id = "test-tool-exec-id"
    mock_db_client.upsert_execution_input.return_value = (
        mock_node_exec_result,
        {"query": "t"},
    )
    mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"}

    # ExecutionProcessor mock: on_node_execution returns supplied error.
    mock_processor = AsyncMock()
    mock_processor.running_node_execution = defaultdict(MagicMock)
    mock_processor.execution_stats = MagicMock()
    mock_processor.execution_stats_lock = threading.Lock()
    mock_node_stats = MagicMock()
    mock_node_stats.error = tool_stats_error
    mock_processor.on_node_execution = AsyncMock(return_value=mock_node_stats)
    # Explicit None check: the default must not depend on the truthiness of
    # whatever mock object the caller passes in.
    if charge_node_usage_mock is None:
        charge_node_usage_mock = AsyncMock(return_value=(10, 990))
    mock_processor.charge_node_usage = charge_node_usage_mock

    # Build a tool_info shaped like _build_tool_info_from_args output.
    tool_def = {
        "type": "function",
        "function": {
            "name": "search_keywords",
            "_sink_node_id": "test-sink-node-id",
            "_field_mapping": {},
            "parameters": {
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
            },
        },
    }
    tool_info = OrchestratorBlock._build_tool_info_from_args(
        tool_call_id="call-1",
        tool_name="search_keywords",
        tool_args={"query": "t"},
        tool_def=tool_def,
    )

    exec_params = ExecutionParams(
        user_id="u",
        graph_id="g",
        node_id="n",
        graph_version=1,
        graph_exec_id="ge",
        node_exec_id="ne",
        execution_context=ExecutionContext(
            human_in_the_loop_safe_mode=False, dry_run=dry_run
        ),
    )

    raised = None
    with patch(
        "backend.blocks.orchestrator.get_database_manager_async_client",
        return_value=mock_db_client,
    ):
        try:
            await block._execute_single_tool_with_manager(
                tool_info, exec_params, mock_processor, responses_api=False
            )
        except Exception as e:
            # Captured rather than re-raised so each test can assert on it.
            raised = e

    return mock_processor.charge_node_usage, raised
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_dry_run():
    """A dry run must complete without ever calling charge_node_usage."""
    outcome = await _run_tool_exec_with_stats(dry_run=True, tool_stats_error=None)
    charge_mock, raised = outcome
    assert raised is None
    charge_mock.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_failed_tool():
    """A tool whose stats carry an Exception must not be charged."""
    failure = RuntimeError("tool blew up")
    charge_mock, raised = await _run_tool_exec_with_stats(
        dry_run=False,
        tool_stats_error=failure,
    )
    assert raised is None
    charge_mock.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_execution_skips_charging_on_cancelled_tool():
    """A cancelled tool (BaseException subclass) must not be charged.

    Regression guard: an ``isinstance(error, Exception)`` check would treat
    ``CancelledError`` as "no error" and bill the user for a terminated run.
    """
    import asyncio as _asyncio

    cancellation = _asyncio.CancelledError()
    charge_mock, raised = await _run_tool_exec_with_stats(
        dry_run=False,
        tool_stats_error=cancellation,
    )
    assert raised is None
    charge_mock.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_execution_insufficient_balance_propagates():
    """InsufficientBalanceError raised while charging must escape the call.

    Were it folded into a ToolCallResult instead, the LLM loop would keep
    running with 'tool failed' errors and the user would get unpaid work.
    """
    out_of_credit = InsufficientBalanceError(
        user_id="u", message="nope", balance=0, amount=10
    )
    result = await _run_tool_exec_with_stats(
        dry_run=False,
        tool_stats_error=None,
        charge_node_usage_mock=AsyncMock(side_effect=out_of_credit),
    )
    raised = result[1]
    assert isinstance(raised, InsufficientBalanceError)
|
||||
|
||||
|
||||
# ── on_node_execution FAILED + InsufficientBalanceError notification ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_node_execution_failed_ibe_sends_notification(
    monkeypatch,
    gated_processor,
):
    """FAILED status with an InsufficientBalanceError in stats → user notified.

    ``_on_node_execution`` is replaced by a stand-in that stores an
    ``InsufficientBalanceError`` on ``stats.error`` and reports FAILED,
    simulating a nested tool charge that ran out of balance inside the
    orchestrator. ``on_node_execution`` must then record one
    insufficient-funds notification for user ``"u"`` and must not attempt
    any extra-iteration charge.
    """
    processor, recorded, _inner, db, _stats_cls = gated_processor
    out_of_credit = InsufficientBalanceError(
        user_id="u",
        message="Insufficient balance",
        balance=0,
        amount=30,
    )

    # Simulate _on_node_execution returning FAILED with IBE in stats.error.
    async def _inner_reports_failure(
        self,
        *,
        node,
        node_exec,
        node_exec_progress,
        stats,
        db_client,
        log_metadata,
        nodes_input_masks=None,
        nodes_to_skip=None,
    ):
        stats.error = out_of_credit
        timing = MagicMock(wall_time=0.1, cpu_time=0.1)
        return timing, ExecutionStatus.FAILED

    monkeypatch.setattr(
        manager.ExecutionProcessor,
        "_on_node_execution",
        _inner_reports_failure,
    )
    db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0))

    graph_stats = MagicMock(
        node_count=0,
        nodes_cputime=0,
        nodes_walltime=0,
        cost=0,
        node_error_count=0,
    )
    await processor.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, threading.Lock()),
    )

    # The notification must have fired so the user knows why their run stopped.
    notifications = recorded["handle_insufficient_funds_notif"]
    assert len(notifications) == 1
    assert notifications[0]["user_id"] == "u"
    # charge_extra_iterations must NOT be called — status is FAILED.
    assert recorded["charge_extra_iterations"] == []
|
||||
|
||||
|
||||
# ── Billing leak: non-IBE exception during extra-iteration charging ──
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_on_node_execution_non_ibe_billing_failure_keeps_completed(
    monkeypatch,
    gated_processor,
):
    """A non-IBE failure while charging must not corrupt execution stats.

    ``charge_extra_iterations`` is replaced with a coroutine raising
    ``ConnectionError`` (standing in for e.g. a DB outage). Afterwards:

    - the returned stats still have ``error is None`` (the node itself ran
      to completion)
    - no insufficient-funds notification is recorded, since only
      ``InsufficientBalanceError`` triggers one
    """

    async def _db_is_down(self, node_exec, extra_iterations):
        raise ConnectionError("DB connection lost")

    processor, recorded, inner_state, db, _ = gated_processor
    db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3))
    inner_state["llm_call_count"] = 4
    inner_state["status"] = ExecutionStatus.COMPLETED
    monkeypatch.setattr(
        manager.ExecutionProcessor, "charge_extra_iterations", _db_is_down
    )

    graph_stats = MagicMock(
        node_count=0,
        nodes_cputime=0,
        nodes_walltime=0,
        cost=0,
        node_error_count=0,
    )
    result_stats = await processor.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, threading.Lock()),
    )

    # error stays None — node completed, only billing failed.
    assert result_stats.error is None
    # No notification was sent (only IBE triggers notification).
    assert len(recorded["handle_insufficient_funds_notif"]) == 0
|
||||
|
||||
|
||||
# ── _charge_usage with execution_count=0 ──
|
||||
|
||||
|
||||
class TestChargeUsageZeroExecutionCount:
    """Checks that ``_charge_usage(node_exec, 0)` bills the block cost only."""

    def test_execution_count_zero_skips_execution_tier(self, monkeypatch):
        """With execution_count=0, execution_usage_cost must never run.

        The block cost is stubbed to 10 and ``spend_credits`` to return a
        remaining balance of 500, so the expected result is ``(10, 500)``
        with a single spend of 10 and no execution-tier lookup.
        """
        tier_lookups = []
        debits: list[int] = []

        def _tier_cost_spy(count):
            # Would charge 100 if it were ever reached.
            tier_lookups.append(count)
            return (100, count)

        class _RecordingDb:
            def spend_credits(self, *, user_id, cost, metadata):
                debits.append(cost)
                return 500

        stub_block = MagicMock()
        stub_block.name = "FakeBlock"

        monkeypatch.setattr(manager, "get_db_client", lambda: _RecordingDb())
        monkeypatch.setattr(manager, "get_block", lambda block_id: stub_block)
        monkeypatch.setattr(
            manager,
            "block_usage_cost",
            lambda block, input_data, **_kw: (10, {}),
        )
        monkeypatch.setattr(manager, "execution_usage_cost", _tier_cost_spy)

        # Bypass __init__: only _charge_usage is exercised here.
        processor = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        node_exec = MagicMock(
            user_id="u",
            graph_exec_id="ge",
            graph_id="g",
            node_exec_id="ne",
            node_id="n",
            block_id="b",
            inputs={},
        )

        total_cost, remaining = processor._charge_usage(node_exec, 0)

        assert total_cost == 10  # block cost only
        assert remaining == 500
        assert debits == [10]
        # execution_usage_cost must NOT have been called
        assert tier_lookups == []
|
||||
@@ -956,6 +956,12 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
ep.execution_stats_lock = threading.Lock()
|
||||
ns = MagicMock(error=None)
|
||||
ep.on_node_execution = AsyncMock(return_value=ns)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Must be AsyncMock because it is async and is awaited in
|
||||
# _execute_single_tool_with_manager — a plain MagicMock would return a
|
||||
# non-awaitable tuple and TypeError out, then be silently swallowed by
|
||||
# the orchestrator's catch-all.
|
||||
ep.charge_node_usage = AsyncMock(return_value=(0, 0))
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=tool_sigs
|
||||
|
||||
@@ -19,7 +19,7 @@ from sentry_sdk.api import flush as _sentry_flush
|
||||
from sentry_sdk.api import get_current_scope as _sentry_get_current_scope
|
||||
|
||||
from backend.blocks import get_block
|
||||
from backend.blocks._base import BlockSchema
|
||||
from backend.blocks._base import Block, BlockSchema
|
||||
from backend.blocks.agent import AgentExecutorBlock
|
||||
from backend.blocks.io import AgentOutputBlock
|
||||
from backend.blocks.mcp.block import MCPToolBlock
|
||||
@@ -681,6 +681,10 @@ class ExecutionProcessor:
|
||||
execution_stats.walltime = timing_info.wall_time
|
||||
execution_stats.cputime = timing_info.cpu_time
|
||||
|
||||
await self._handle_post_execution_billing(
|
||||
node, node_exec, execution_stats, status, log_metadata
|
||||
)
|
||||
|
||||
graph_stats, graph_stats_lock = graph_stats_pair
|
||||
with graph_stats_lock:
|
||||
graph_stats.node_count += 1 + execution_stats.extra_steps
|
||||
@@ -716,8 +720,121 @@ class ExecutionProcessor:
|
||||
db_client=db_client,
|
||||
)
|
||||
|
||||
# If the node failed because a nested tool charge raised IBE,
|
||||
# send the user notification so they understand why the run stopped.
|
||||
if status == ExecutionStatus.FAILED and isinstance(
|
||||
execution_stats.error, InsufficientBalanceError
|
||||
):
|
||||
await self._try_send_insufficient_funds_notif(
|
||||
node_exec.user_id,
|
||||
node_exec.graph_id,
|
||||
execution_stats.error,
|
||||
log_metadata,
|
||||
)
|
||||
|
||||
return execution_stats
|
||||
|
||||
async def _try_send_insufficient_funds_notif(
    self,
    user_id: str,
    graph_id: str,
    error: InsufficientBalanceError,
    log_metadata: LogMetadata,
) -> None:
    """Best-effort dispatch of the insufficient-funds notification.

    Runs ``_handle_insufficient_funds_notif`` through ``asyncio.to_thread``.
    Any ``Exception`` raised along the way is reduced to a warning on
    ``log_metadata`` so a notification failure never reaches the caller.
    """
    try:
        db_client = get_db_client()
        notify = self._handle_insufficient_funds_notif
        await asyncio.to_thread(notify, db_client, user_id, graph_id, error)
    except Exception as notif_error:  # pragma: no cover
        log_metadata.warning(
            f"Failed to send insufficient funds notification: {notif_error}"
        )
|
||||
|
||||
async def _handle_post_execution_billing(
    self,
    node: Node,
    node_exec: NodeExecutionEntry,
    execution_stats: NodeExecutionStats,
    status: ExecutionStatus,
    log_metadata: LogMetadata,
) -> None:
    """Charge extra iterations for blocks that opt into per-LLM-call billing.

    The first LLM call is already covered by ``_charge_usage()``; each
    additional call costs another ``base_cost``. Skipped for dry runs and
    for any run whose status is not COMPLETED.

    A failure to charge is a post-hoc billing leak: the work is already
    done, so the run stays COMPLETED, ``execution_stats.error`` is left
    untouched and the failure is logged with ``billing_leak: True`` for
    alerting. ``InsufficientBalanceError`` additionally triggers the
    insufficient-funds notification.

    The low-balance check that follows a *successful* charge is handled
    separately: if it fails, the money has already been collected, so the
    failure is logged as a plain warning rather than as a billing leak.

    Args:
        node: Node whose block decides how many extra charges apply.
        node_exec: Execution entry being billed.
        execution_stats: Stats of the finished run; ``extra_cost`` is
            increased by the amount actually charged.
        status: Final status of the node execution.
        log_metadata: Logger used for billing-leak and warning records.
    """
    if status != ExecutionStatus.COMPLETED or node_exec.execution_context.dry_run:
        return
    extra_iterations = node.block.extra_credit_charges(execution_stats)
    if extra_iterations <= 0:
        return

    try:
        extra_cost, remaining_balance = await self.charge_extra_iterations(
            node_exec,
            extra_iterations,
        )
    except InsufficientBalanceError as e:
        log_metadata.error(
            "billing_leak: insufficient balance after "
            f"{node.block.name} completed {extra_iterations} "
            f"extra iterations",
            extra={
                "billing_leak": True,
                "user_id": node_exec.user_id,
                "graph_id": node_exec.graph_id,
                "block_id": node_exec.block_id,
                "extra_iterations": extra_iterations,
                "error": str(e),
            },
        )
        # Do NOT set execution_stats.error — the node ran to completion,
        # only the post-hoc charge failed.
        await self._try_send_insufficient_funds_notif(
            node_exec.user_id,
            node_exec.graph_id,
            e,
            log_metadata,
        )
        return
    except Exception as e:
        log_metadata.error(
            f"billing_leak: failed to charge extra iterations "
            f"for {node.block.name}",
            extra={
                "billing_leak": True,
                "user_id": node_exec.user_id,
                "graph_id": node_exec.graph_id,
                "block_id": node_exec.block_id,
                "extra_iterations": extra_iterations,
                "error_type": type(e).__name__,
                "error": str(e),
            },
            exc_info=True,
        )
        return

    if extra_cost <= 0:
        return
    execution_stats.extra_cost += extra_cost

    # The charge went through; a failure from here on is NOT a billing leak
    # and must not be reported as one.
    try:
        await asyncio.to_thread(
            self._handle_low_balance,
            get_db_client(),
            node_exec.user_id,
            remaining_balance,
            extra_cost,
        )
    except Exception as low_balance_error:
        log_metadata.warning(
            "Failed to run low balance check after charging extra "
            f"iterations for {node.block.name}: {low_balance_error}"
        )
|
||||
|
||||
@async_time_measured
|
||||
async def _on_node_execution(
|
||||
self,
|
||||
@@ -944,6 +1061,27 @@ class ExecutionProcessor:
|
||||
stats=exec_stats,
|
||||
)
|
||||
|
||||
def _resolve_block_cost(
    self,
    node_exec: NodeExecutionEntry,
) -> tuple[Block | None, int, dict]:
    """Resolve the block for *node_exec* and price a single run of it.

    Keeps the ``get_block`` + ``block_usage_cost`` lookup in one place for
    :meth:`_charge_usage` and :meth:`charge_extra_iterations`.

    Returns:
        ``(block, cost, matching_filter)``. When the block id cannot be
        resolved an error is logged and ``(None, 0, {})`` is returned, which
        callers should treat as "nothing to charge".
    """
    resolved = get_block(node_exec.block_id)
    if resolved:
        cost, matching_filter = block_usage_cost(
            block=resolved, input_data=node_exec.inputs
        )
        return resolved, cost, matching_filter

    logger.error(f"Block {node_exec.block_id} not found.")
    return None, 0, {}
|
||||
|
||||
def _charge_usage(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
@@ -952,14 +1090,10 @@ class ExecutionProcessor:
|
||||
total_cost = 0
|
||||
remaining_balance = 0
|
||||
db_client = get_db_client()
|
||||
block = get_block(node_exec.block_id)
|
||||
block, cost, matching_filter = self._resolve_block_cost(node_exec)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return total_cost, 0
|
||||
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -977,7 +1111,13 @@ class ExecutionProcessor:
|
||||
)
|
||||
total_cost += cost
|
||||
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
# execution_count=0 is used by charge_node_usage for nested tool calls
|
||||
# which must not be pushed into higher execution-count tiers.
|
||||
# execution_usage_cost(0) would trigger a charge because 0 % threshold == 0,
|
||||
# so skip it entirely when execution_count is 0.
|
||||
cost, usage_count = (
|
||||
execution_usage_cost(execution_count) if execution_count > 0 else (0, 0)
|
||||
)
|
||||
if cost > 0:
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -996,6 +1136,116 @@ class ExecutionProcessor:
|
||||
|
||||
return total_cost, remaining_balance
|
||||
|
||||
# Hard cap on the multiplier passed to charge_extra_iterations to
|
||||
# protect against a corrupted llm_call_count draining a user's balance.
|
||||
# Real agent-mode runs are bounded by agent_mode_max_iterations (~50);
|
||||
# 200 leaves headroom while preventing runaway charges.
|
||||
_MAX_EXTRA_ITERATIONS = 200
|
||||
|
||||
def _charge_extra_iterations_sync(
    self,
    node_exec: NodeExecutionEntry,
    capped_iterations: int,
) -> tuple[int, int]:
    """Blocking worker behind :meth:`charge_extra_iterations`.

    Intended to run in a thread-pool worker; do not call it directly from
    async code.

    The block cost is resolved again here via ``_resolve_block_cost`` rather
    than reusing the value computed by ``_charge_usage`` at the start of
    execution, so no state has to be shared between the two calls.

    Returns:
        ``(total_extra_cost, remaining_balance)``, or ``(0, 0)`` when the
        block cannot be resolved or its cost is not positive (nothing is
        spent in that case).
    """
    db_client = get_db_client()
    block, unit_cost, matching_filter = self._resolve_block_cost(node_exec)
    if not (block and unit_cost > 0):
        return 0, 0

    total_extra_cost = unit_cost * capped_iterations
    charge_input = {
        **matching_filter,
        "extra_iterations": capped_iterations,
    }
    charge_reason = (
        f"Extra agent-mode iterations for {block.name} "
        f"({capped_iterations} additional LLM calls)"
    )
    transaction = UsageTransactionMetadata(
        graph_exec_id=node_exec.graph_exec_id,
        graph_id=node_exec.graph_id,
        node_exec_id=node_exec.node_exec_id,
        node_id=node_exec.node_id,
        block_id=node_exec.block_id,
        block=block.name,
        input=charge_input,
        reason=charge_reason,
    )
    remaining_balance = db_client.spend_credits(
        user_id=node_exec.user_id,
        cost=total_extra_cost,
        metadata=transaction,
    )
    return total_extra_cost, remaining_balance
|
||||
|
||||
async def charge_extra_iterations(
    self,
    node_exec: NodeExecutionEntry,
    extra_iterations: int,
) -> tuple[int, int]:
    """Bill *extra_iterations* additional runs of a block's base cost.

    Meant for agent-mode blocks (e.g. OrchestratorBlock) that make several
    LLM calls inside one node execution: the first iteration is charged by
    :meth:`_charge_usage`, the remainder through this method.

    The count is clamped to ``_MAX_EXTRA_ITERATIONS`` before charging, and
    the blocking work is delegated to ``_charge_extra_iterations_sync`` via
    ``asyncio.to_thread``.

    Returns:
        ``(total_extra_cost, remaining_balance)``; ``(0, 0)`` when
        *extra_iterations* is zero or negative.

    Raises:
        InsufficientBalanceError: May be raised if the user can't afford
            the charge.
    """
    if extra_iterations > 0:
        # Cap to protect against a corrupted llm_call_count.
        billable = min(self._MAX_EXTRA_ITERATIONS, extra_iterations)
        charge = self._charge_extra_iterations_sync
        return await asyncio.to_thread(charge, node_exec, billable)
    return 0, 0
|
||||
|
||||
def _charge_and_check_balance(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
) -> tuple[int, int]:
|
||||
"""Charge usage and check low balance in a single thread-pool worker.
|
||||
|
||||
Combines ``_charge_usage`` and ``_handle_low_balance`` to avoid
|
||||
dispatching two thread-pool calls per tool execution.
|
||||
"""
|
||||
total_cost, remaining = self._charge_usage(node_exec, 0)
|
||||
if total_cost > 0:
|
||||
self._handle_low_balance(
|
||||
get_db_client(), node_exec.user_id, remaining, total_cost
|
||||
)
|
||||
return total_cost, remaining
|
||||
|
||||
async def charge_node_usage(
    self,
    node_exec: NodeExecutionEntry,
) -> tuple[int, int]:
    """Charge one node execution to the user from async code.

    Public async entry point for blocks (e.g. the OrchestratorBlock) that
    spawn nested node executions outside the main queue and therefore have
    to bill them explicitly. It runs ``_charge_and_check_balance`` in a
    worker thread, so the low-balance notification is handled as well and
    callers don't need to touch private methods.

    Note: the global execution counter (``increment_execution_count``) is
    **not** incremented here. Nested tool executions are sub-steps of a
    single block run from the user's perspective and should not push them
    into higher per-execution cost tiers.

    Returns:
        ``(total_cost, remaining_balance)``.
    """
    outcome = await asyncio.to_thread(self._charge_and_check_balance, node_exec)
    return outcome
|
||||
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user