diff --git a/autogpt_platform/backend/backend/blocks/_base.py b/autogpt_platform/backend/backend/blocks/_base.py index 56986d15c4..cbebdee85f 100644 --- a/autogpt_platform/backend/backend/blocks/_base.py +++ b/autogpt_platform/backend/backend/blocks/_base.py @@ -420,6 +420,19 @@ class BlockWebhookConfig(BlockManualWebhookConfig): class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): _optimized_description: ClassVar[str | None] = None + def extra_credit_charges(self, execution_stats: "NodeExecutionStats") -> int: + """Return extra credits to charge after this block run completes. + + Called by the executor after a block finishes with COMPLETED status. + The return value is the number of additional base-cost credits to + charge beyond the single credit already collected by ``_charge_usage`` + at the start of execution. Defaults to 0 (no extra charges). + + Override in blocks (e.g. OrchestratorBlock) that make multiple LLM + calls within one run and should be billed per call. + """ + return 0 + def __init__( self, id: str = "", diff --git a/autogpt_platform/backend/backend/blocks/orchestrator.py b/autogpt_platform/backend/backend/blocks/orchestrator.py index 6fbff643fb..3e956e8d6f 100644 --- a/autogpt_platform/backend/backend/blocks/orchestrator.py +++ b/autogpt_platform/backend/backend/blocks/orchestrator.py @@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext from backend.data.model import NodeExecutionStats, SchemaField from backend.util import json from backend.util.clients import get_database_manager_async_client +from backend.util.exceptions import InsufficientBalanceError from backend.util.prompt import MAIN_OBJECTIVE_PREFIX from backend.util.security import SENSITIVE_FIELD_NAMES from backend.util.tool_call_loop import ( @@ -364,10 +365,25 @@ def _disambiguate_tool_names(tools: list[dict[str, Any]]) -> None: class OrchestratorBlock(Block): + """A block that uses a language model to orchestrate tool calls. 
+ + Supports both single-shot and iterative agent mode execution. + + **InsufficientBalanceError propagation contract**: ``InsufficientBalanceError`` + (IBE) must always re-raise through every ``except`` block in this class. + Swallowing IBE would let the agent loop continue with unpaid work. Every + exception handler that catches ``Exception`` includes an explicit IBE + re-raise carve-out for this reason. """ - A block that uses a language model to orchestrate tool calls, supporting both - single-shot and iterative agent mode execution. - """ + + def extra_credit_charges(self, execution_stats: NodeExecutionStats) -> int: + """Charge one extra base credit per LLM call beyond the first. + + In agent mode each iteration makes one LLM call. The first is already + covered by _charge_usage(); this returns the number of additional + credits so the executor can bill the remaining calls post-completion. + """ + return max(0, execution_stats.llm_call_count - 1) # MCP server name used by the Claude Code SDK execution mode. Keep in sync # with _create_graph_mcp_server and the MCP_PREFIX derivation in _execute_tools_sdk_mode. @@ -1077,7 +1093,10 @@ class OrchestratorBlock(Block): input_data=input_value, ) - assert node_exec_result is not None, "node_exec_result should not be None" + if node_exec_result is None: + raise RuntimeError( + f"upsert_execution_input returned None for node {sink_node_id}" + ) # Create NodeExecutionEntry for execution manager node_exec_entry = NodeExecutionEntry( @@ -1112,15 +1131,79 @@ class OrchestratorBlock(Block): task=node_exec_future, ) - # Execute the node directly since we're in the Orchestrator context - node_exec_future.set_result( - await execution_processor.on_node_execution( + # Execute the node directly since we're in the Orchestrator context. + # Wrap in try/except so the future is always resolved, even on + # error — an unresolved Future would block anything awaiting it. 
+ # + # on_node_execution is decorated with @async_error_logged(swallow=True), + # which catches BaseException and returns None rather than raising. + # Treat a None return as a failure: set_exception so the future + # carries an error state rather than a None result, and return an + # error response so the LLM knows the tool failed. + try: + tool_node_stats = await execution_processor.on_node_execution( node_exec=node_exec_entry, node_exec_progress=node_exec_progress, nodes_input_masks=None, graph_stats_pair=graph_stats_pair, ) - ) + if tool_node_stats is None: + nil_err = RuntimeError( + f"on_node_execution returned None for node {sink_node_id} " + "(error was swallowed by @async_error_logged)" + ) + node_exec_future.set_exception(nil_err) + resp = _create_tool_response( + tool_call.id, + "Tool execution returned no result", + responses_api=responses_api, + ) + resp["_is_error"] = True + return resp + node_exec_future.set_result(tool_node_stats) + except Exception as exec_err: + node_exec_future.set_exception(exec_err) + raise + + # Charge user credits AFTER successful tool execution. Tools + # spawned by the orchestrator bypass the main execution queue + # (where _charge_usage is called), so we must charge here to + # avoid free tool execution. Charging post-completion (vs. + # pre-execution) avoids billing users for failed tool calls. + # Skipped for dry runs. + # + # `error is None` intentionally excludes both Exception and + # BaseException subclasses (e.g. CancelledError) so cancelled + # or terminated tool runs are not billed. + # + # Billing errors (including non-balance exceptions) are kept + # in a separate try/except so they are never silently swallowed + # by the generic tool-error handler below. 
+ if ( + not execution_params.execution_context.dry_run + and tool_node_stats.error is None + ): + try: + tool_cost, _ = await execution_processor.charge_node_usage( + node_exec_entry, + ) + except InsufficientBalanceError: + # IBE must propagate — see OrchestratorBlock class docstring. + raise + except Exception: + # Non-billing charge failures (DB outage, network, etc.) + # must NOT propagate to the outer except handler because + # the tool itself succeeded. Re-raising would mark the + # tool as failed (_is_error=True), causing the LLM to + # retry side-effectful operations. Log and continue. + logger.exception( + "Unexpected error charging for tool node %s; " + "tool execution was successful", + sink_node_id, + ) + tool_cost = 0 + if tool_cost > 0: + self.merge_stats(NodeExecutionStats(extra_cost=tool_cost)) # Get outputs from database after execution completes using database manager client node_outputs = await db_client.get_execution_outputs_by_node_exec_id( @@ -1133,18 +1216,26 @@ class OrchestratorBlock(Block): if node_outputs else "Tool executed successfully" ) - return _create_tool_response( + resp = _create_tool_response( tool_call.id, tool_response_content, responses_api=responses_api ) + resp["_is_error"] = False + return resp + except InsufficientBalanceError: + # IBE must propagate — see class docstring. + raise except Exception as e: - logger.warning("Tool execution with manager failed: %s", e) - # Return error response - return _create_tool_response( + logger.warning("Tool execution with manager failed: %s", e, exc_info=True) + # Return a generic error to the LLM — internal exception messages + # may contain server paths, DB details, or infrastructure info. 
+ resp = _create_tool_response( tool_call.id, - f"Tool execution failed: {e}", + "Tool execution failed due to an internal error", responses_api=responses_api, ) + resp["_is_error"] = True + return resp async def _agent_mode_llm_caller( self, @@ -1244,13 +1335,16 @@ class OrchestratorBlock(Block): content = str(raw_content) else: content = "Tool executed successfully" - tool_failed = content.startswith("Tool execution failed:") + tool_failed = result.get("_is_error", True) return ToolCallResult( tool_call_id=tool_call.id, tool_name=tool_call.name, content=content, is_error=tool_failed, ) + except InsufficientBalanceError: + # IBE must propagate — see class docstring. + raise except Exception as e: logger.error("Tool execution failed: %s", e) return ToolCallResult( @@ -1370,9 +1464,13 @@ class OrchestratorBlock(Block): "arguments": tc.arguments, }, ) + except InsufficientBalanceError: + # IBE must propagate — see class docstring. + raise except Exception as e: - # Catch all errors (validation, network, API) so that the block - # surfaces them as user-visible output instead of crashing. + # Catch all OTHER errors (validation, network, API) so that + # the block surfaces them as user-visible output instead of + # crashing. yield "error", str(e) return @@ -1450,11 +1548,14 @@ class OrchestratorBlock(Block): text = content else: text = json.dumps(content) - tool_failed = text.startswith("Tool execution failed:") + tool_failed = result.get("_is_error", True) return { "content": [{"type": "text", "text": text}], "isError": tool_failed, } + except InsufficientBalanceError: + # IBE must propagate — see class docstring. + raise except Exception as e: logger.error("SDK tool execution failed: %s", e) return { @@ -1733,11 +1834,15 @@ class OrchestratorBlock(Block): await pending_task except (asyncio.CancelledError, StopAsyncIteration): pass + except InsufficientBalanceError: + # IBE must propagate — see class docstring. 
The `finally` + # block below still runs and records partial token usage. + raise except Exception as e: - # Surface SDK errors as user-visible output instead of crashing, - # consistent with _execute_tools_agent_mode error handling. - # Don't return yet — fall through to merge_stats below so - # partial token usage is always recorded. + # Surface OTHER SDK errors as user-visible output instead + # of crashing, consistent with _execute_tools_agent_mode + # error handling. Don't return yet — fall through to + # merge_stats below so partial token usage is always recorded. sdk_error = e finally: # Always record usage stats, even on error. The SDK may have diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator.py index 55f137428f..2eb27012dc 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_orchestrator.py +++ b/autogpt_platform/backend/backend/blocks/test/test_orchestrator.py @@ -922,6 +922,11 @@ async def test_orchestrator_agent_mode(): mock_execution_processor.on_node_execution = AsyncMock( return_value=mock_node_stats ) + # Mock charge_node_usage (called after successful tool execution). + # Returns (cost, remaining_balance). Must be AsyncMock because it is + # an async method and is directly awaited in _execute_single_tool_with_manager. + # Use a non-zero cost so the merge_stats branch is exercised. + mock_execution_processor.charge_node_usage = AsyncMock(return_value=(10, 990)) # Mock the get_execution_outputs_by_node_exec_id method mock_db_client.get_execution_outputs_by_node_exec_id.return_value = { @@ -967,6 +972,11 @@ async def test_orchestrator_agent_mode(): # Verify tool was executed via execution processor assert mock_execution_processor.on_node_execution.call_count == 1 + # Verify charge_node_usage was actually called for the successful + # tool execution — this guards against regressions where the + # post-execution tool charging is accidentally removed. 
+ assert mock_execution_processor.charge_node_usage.call_count == 1 + @pytest.mark.asyncio async def test_orchestrator_traditional_mode_default(): diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py index 1069fc8ad5..f2242ea527 100644 --- a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py +++ b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_dynamic_fields.py @@ -641,6 +641,14 @@ async def test_validation_errors_dont_pollute_conversation(): mock_execution_processor.on_node_execution.return_value = ( mock_node_stats ) + # Mock charge_node_usage (called after successful tool execution). + # Must be AsyncMock because it is async and is awaited in + # _execute_single_tool_with_manager — a plain MagicMock would + # return a non-awaitable tuple and TypeError out, then be + # silently swallowed by the orchestrator's catch-all. + mock_execution_processor.charge_node_usage = AsyncMock( + return_value=(0, 0) + ) async for output_name, output_value in block.run( input_data, diff --git a/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py new file mode 100644 index 0000000000..3dc9e9b9ae --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/test/test_orchestrator_per_iteration_cost.py @@ -0,0 +1,973 @@ +"""Tests for OrchestratorBlock per-iteration cost charging. + +The OrchestratorBlock in agent mode makes multiple LLM calls in a single +node execution. The executor uses ``Block.extra_credit_charges`` to detect +this and charge ``base_cost * (llm_call_count - 1)`` extra credits after +the block completes. 
+""" + +import threading +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from backend.blocks._base import Block +from backend.blocks.orchestrator import ExecutionParams, OrchestratorBlock +from backend.data.execution import ExecutionContext, ExecutionStatus +from backend.data.model import NodeExecutionStats +from backend.executor import manager +from backend.util.exceptions import InsufficientBalanceError + +# ── extra_credit_charges hook ──────────────────────────────────────── + + +class _NoOpBlock(Block): + """Minimal concrete Block subclass that does not override extra_credit_charges.""" + + def __init__(self): + super().__init__(id="noop-block", description="No-op test block") + + def run(self, input_data, **kwargs): # type: ignore[override] + yield "out", {} + + +class TestExtraCreditCharges: + """OrchestratorBlock opts into per-LLM-call billing via extra_credit_charges.""" + + def test_orchestrator_returns_nonzero_for_multiple_calls(self): + block = OrchestratorBlock() + stats = NodeExecutionStats(llm_call_count=3) + assert block.extra_credit_charges(stats) == 2 + + def test_orchestrator_returns_zero_for_single_call(self): + block = OrchestratorBlock() + stats = NodeExecutionStats(llm_call_count=1) + assert block.extra_credit_charges(stats) == 0 + + def test_orchestrator_returns_zero_for_zero_calls(self): + block = OrchestratorBlock() + stats = NodeExecutionStats(llm_call_count=0) + assert block.extra_credit_charges(stats) == 0 + + def test_default_block_returns_zero(self): + """A block that does not override extra_credit_charges returns 0.""" + block = _NoOpBlock() + stats = NodeExecutionStats(llm_call_count=10) + assert block.extra_credit_charges(stats) == 0 + + +# ── charge_extra_iterations math ─────────────────────────────────── + + +@pytest.fixture() +def fake_node_exec(): + node_exec = MagicMock() + node_exec.user_id = "u" + node_exec.graph_exec_id = "g" + node_exec.graph_id = "g" + node_exec.node_exec_id = "ne" + node_exec.node_id 
= "n" + node_exec.block_id = "b" + node_exec.inputs = {} + return node_exec + + +@pytest.fixture() +def patched_processor(monkeypatch): + """ExecutionProcessor with stubbed db client / block lookup helpers. + + Returns the processor and a list of credit amounts spent so tests can + assert on what was charged. + + Note: ``ExecutionProcessor.__new__()`` bypasses ``__init__`` — if + ``__init__`` gains required state in the future this fixture will need + updating. + """ + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + return 1000 # remaining balance + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + manager, + "block_usage_cost", + lambda block, input_data, **_kw: (10, {"model": "claude-sonnet-4-6"}), + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + return proc, spent + + +class TestChargeExtraIterations: + @pytest.mark.asyncio + async def test_zero_extra_iterations_charges_nothing( + self, patched_processor, fake_node_exec + ): + proc, spent = patched_processor + cost, balance = await proc.charge_extra_iterations( + fake_node_exec, extra_iterations=0 + ) + assert cost == 0 + assert balance == 0 + assert spent == [] + + @pytest.mark.asyncio + async def test_extra_iterations_multiplies_base_cost( + self, patched_processor, fake_node_exec + ): + proc, spent = patched_processor + cost, balance = await proc.charge_extra_iterations( + fake_node_exec, extra_iterations=4 + ) + assert cost == 40 # 4 × 10 + assert balance == 1000 + assert spent == [40] + + @pytest.mark.asyncio + async def test_negative_extra_iterations_charges_nothing( + self, patched_processor, fake_node_exec + ): + proc, spent = patched_processor + cost, balance = await proc.charge_extra_iterations( + fake_node_exec, 
extra_iterations=-1 + ) + assert cost == 0 + assert balance == 0 + assert spent == [] + + @pytest.mark.asyncio + async def test_capped_at_max(self, monkeypatch, fake_node_exec): + """Runaway llm_call_count is capped at _MAX_EXTRA_ITERATIONS.""" + + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + return 1000 + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + manager, + "block_usage_cost", + lambda block, input_data, **_kw: (10, {}), + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cap = manager.ExecutionProcessor._MAX_EXTRA_ITERATIONS + cost, _ = await proc.charge_extra_iterations( + fake_node_exec, extra_iterations=cap * 100 + ) + # Charged at most cap × 10 + assert cost == cap * 10 + assert spent == [cap * 10] + + @pytest.mark.asyncio + async def test_zero_base_cost_skips_charge(self, monkeypatch, fake_node_exec): + + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + return 0 + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + manager, "block_usage_cost", lambda block, input_data, **_kw: (0, {}) + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_extra_iterations( + fake_node_exec, extra_iterations=4 + ) + assert cost == 0 + assert balance == 0 + assert spent == [] + + @pytest.mark.asyncio + async def test_block_not_found_skips_charge(self, monkeypatch, fake_node_exec): + + spent: list[int] = [] + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + spent.append(cost) + 
return 0 + + monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(manager, "get_block", lambda block_id: None) + monkeypatch.setattr( + manager, "block_usage_cost", lambda block, input_data, **_kw: (10, {}) + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_extra_iterations( + fake_node_exec, extra_iterations=3 + ) + assert cost == 0 + assert balance == 0 + assert spent == [] + + @pytest.mark.asyncio + async def test_propagates_insufficient_balance_error( + self, monkeypatch, fake_node_exec + ): + """Out-of-credits errors must propagate, not be silently swallowed.""" + + class FakeDb: + def spend_credits(self, *, user_id, cost, metadata): + raise InsufficientBalanceError( + user_id=user_id, + message="Insufficient balance", + balance=0, + amount=cost, + ) + + fake_block = MagicMock() + fake_block.name = "FakeBlock" + + monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb()) + monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block) + monkeypatch.setattr( + manager, "block_usage_cost", lambda block, input_data, **_kw: (10, {}) + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + with pytest.raises(InsufficientBalanceError): + await proc.charge_extra_iterations(fake_node_exec, extra_iterations=4) + + +# ── charge_node_usage ────────────────────────────────────────────── + + +class TestChargeNodeUsage: + """charge_node_usage delegates to _charge_usage with execution_count=0.""" + + @pytest.mark.asyncio + async def test_delegates_with_zero_execution_count( + self, monkeypatch, fake_node_exec + ): + """Nested tool charges should NOT inflate the per-execution counter.""" + + captured: dict = {} + + def fake_charge_usage(self, node_exec, execution_count): + captured["execution_count"] = execution_count + captured["node_exec"] = node_exec + return (5, 100) + + def fake_handle_low_balance( + self, db_client, user_id, 
current_balance, transaction_cost + ): + pass + + monkeypatch.setattr( + manager.ExecutionProcessor, "_charge_usage", fake_charge_usage + ) + monkeypatch.setattr( + manager.ExecutionProcessor, "_handle_low_balance", fake_handle_low_balance + ) + monkeypatch.setattr(manager, "get_db_client", lambda: MagicMock()) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_node_usage(fake_node_exec) + assert cost == 5 + assert balance == 100 + assert captured["execution_count"] == 0 + + @pytest.mark.asyncio + async def test_calls_handle_low_balance_when_cost_nonzero( + self, monkeypatch, fake_node_exec + ): + """charge_node_usage should call _handle_low_balance when total_cost > 0.""" + + low_balance_calls: list[dict] = [] + + def fake_charge_usage(self, node_exec, execution_count): + return (10, 50) + + def fake_handle_low_balance( + self, db_client, user_id, current_balance, transaction_cost + ): + low_balance_calls.append( + { + "user_id": user_id, + "current_balance": current_balance, + "transaction_cost": transaction_cost, + } + ) + + monkeypatch.setattr( + manager.ExecutionProcessor, "_charge_usage", fake_charge_usage + ) + monkeypatch.setattr( + manager.ExecutionProcessor, "_handle_low_balance", fake_handle_low_balance + ) + monkeypatch.setattr(manager, "get_db_client", lambda: MagicMock()) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_node_usage(fake_node_exec) + assert cost == 10 + assert balance == 50 + assert len(low_balance_calls) == 1 + assert low_balance_calls[0]["user_id"] == "u" + assert low_balance_calls[0]["current_balance"] == 50 + assert low_balance_calls[0]["transaction_cost"] == 10 + + @pytest.mark.asyncio + async def test_skips_handle_low_balance_when_cost_zero( + self, monkeypatch, fake_node_exec + ): + """charge_node_usage should NOT call _handle_low_balance when cost is 0.""" + + low_balance_calls: list = [] + + def 
fake_charge_usage(self, node_exec, execution_count): + return (0, 200) + + def fake_handle_low_balance( + self, db_client, user_id, current_balance, transaction_cost + ): + low_balance_calls.append(True) + + monkeypatch.setattr( + manager.ExecutionProcessor, "_charge_usage", fake_charge_usage + ) + monkeypatch.setattr( + manager.ExecutionProcessor, "_handle_low_balance", fake_handle_low_balance + ) + monkeypatch.setattr(manager, "get_db_client", lambda: MagicMock()) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + cost, balance = await proc.charge_node_usage(fake_node_exec) + assert cost == 0 + assert low_balance_calls == [] + + +# ── on_node_execution charging gate ──────────────────────────────── + + +class _FakeNode: + """Minimal stand-in for a ``Node`` object with a block attribute.""" + + def __init__(self, extra_charges: int = 0, block_name: str = "FakeBlock"): + self.block = MagicMock() + self.block.name = block_name + self.block.extra_credit_charges = MagicMock(return_value=extra_charges) + + +class _FakeExecContext: + def __init__(self, dry_run: bool = False): + self.dry_run = dry_run + + +def _make_node_exec(dry_run: bool = False) -> MagicMock: + """Build a NodeExecutionEntry-like mock for on_node_execution tests.""" + ne = MagicMock() + ne.user_id = "u" + ne.graph_id = "g" + ne.graph_exec_id = "ge" + ne.node_id = "n" + ne.node_exec_id = "ne" + ne.block_id = "b" + ne.inputs = {} + ne.execution_context = _FakeExecContext(dry_run=dry_run) + return ne + + +@pytest.fixture() +def gated_processor(monkeypatch): + """ExecutionProcessor with on_node_execution's downstream calls stubbed. + + Lets tests flip the gate conditions (status, extra_credit_charges result, + llm_call_count, dry_run) and observe whether charge_extra_iterations + was called. 
+ """ + + calls: dict[str, list] = { + "charge_extra_iterations": [], + "handle_low_balance": [], + "handle_insufficient_funds_notif": [], + } + + # Stub node lookup + DB client so the wrapper doesn't touch real infra. + fake_db = MagicMock() + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2)) + monkeypatch.setattr(manager, "get_db_async_client", lambda: fake_db) + monkeypatch.setattr(manager, "get_db_client", lambda: fake_db) + # get_block is called by LogMetadata construction in on_node_execution. + monkeypatch.setattr( + manager, + "get_block", + lambda block_id: MagicMock(name="FakeBlock"), + ) + # Persistence + cost logging are not under test here. + monkeypatch.setattr( + manager, + "async_update_node_execution_status", + AsyncMock(return_value=None), + ) + monkeypatch.setattr( + manager, + "async_update_graph_execution_state", + AsyncMock(return_value=None), + ) + monkeypatch.setattr( + manager, + "log_system_credential_cost", + AsyncMock(return_value=None), + ) + + proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor) + + # Control the status returned by the inner execution call. 
+ inner_result = {"status": ExecutionStatus.COMPLETED, "llm_call_count": 3} + + async def fake_inner( + self, + *, + node, + node_exec, + node_exec_progress, + stats, + db_client, + log_metadata, + nodes_input_masks=None, + nodes_to_skip=None, + ): + stats.llm_call_count = inner_result["llm_call_count"] + return MagicMock(wall_time=0.1, cpu_time=0.1), inner_result["status"] + + monkeypatch.setattr( + manager.ExecutionProcessor, + "_on_node_execution", + fake_inner, + ) + + async def fake_charge_extra(self, node_exec, extra_iterations): + calls["charge_extra_iterations"].append(extra_iterations) + return (extra_iterations * 10, 500) + + monkeypatch.setattr( + manager.ExecutionProcessor, + "charge_extra_iterations", + fake_charge_extra, + ) + + def fake_low_balance(self, db_client, user_id, current_balance, transaction_cost): + calls["handle_low_balance"].append( + { + "user_id": user_id, + "current_balance": current_balance, + "transaction_cost": transaction_cost, + } + ) + + monkeypatch.setattr( + manager.ExecutionProcessor, + "_handle_low_balance", + fake_low_balance, + ) + + def fake_notif(self, db_client, user_id, graph_id, e): + calls["handle_insufficient_funds_notif"].append( + {"user_id": user_id, "graph_id": graph_id, "error": e} + ) + + monkeypatch.setattr( + manager.ExecutionProcessor, + "_handle_insufficient_funds_notif", + fake_notif, + ) + + return proc, calls, inner_result, fake_db, NodeExecutionStats + + +@pytest.mark.asyncio +async def test_on_node_execution_charges_extra_iterations_when_gate_passes( + gated_processor, +): + """COMPLETED + extra_credit_charges > 0 + not dry_run → charged.""" + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.COMPLETED + inner["llm_call_count"] = 3 # → extra_charges = 2 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=2)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + 
await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + assert calls["charge_extra_iterations"] == [2] + # _handle_low_balance must be called with the remaining balance returned by + # charge_extra_iterations (500) so users are alerted when balance drops low. + assert len(calls["handle_low_balance"]) == 1 + + +@pytest.mark.asyncio +async def test_on_node_execution_skips_when_status_not_completed(gated_processor): + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.FAILED + inner["llm_call_count"] = 5 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + assert calls["charge_extra_iterations"] == [] + + +@pytest.mark.asyncio +async def test_on_node_execution_skips_when_extra_charges_zero(gated_processor): + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.COMPLETED + inner["llm_call_count"] = 5 + # Block returns 0 extra charges (base class default) + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + assert calls["charge_extra_iterations"] == [] + + +@pytest.mark.asyncio +async def test_on_node_execution_skips_when_dry_run(gated_processor): + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = 
ExecutionStatus.COMPLETED + inner["llm_call_count"] = 5 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=4)) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=True), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + assert calls["charge_extra_iterations"] == [] + + +@pytest.mark.asyncio +async def test_on_node_execution_insufficient_balance_records_error_and_notifies( + monkeypatch, + gated_processor, +): + """When extra-iteration charging fails with InsufficientBalanceError: + + - the run still reports COMPLETED (the work is already done) + - execution_stats.error is NOT set (would flip node_error_count and + leak balance amounts into persisted node_stats — see manager.py + comment in the IBE handler) + - _handle_insufficient_funds_notif is called so the user is notified + - the structured ERROR log is the alerting hook + """ + + proc, calls, inner, fake_db, _ = gated_processor + inner["status"] = ExecutionStatus.COMPLETED + inner["llm_call_count"] = 4 + fake_db.get_node = AsyncMock(return_value=_FakeNode(extra_charges=3)) + + async def raise_ibe(self, node_exec, extra_iterations): + raise InsufficientBalanceError( + user_id=node_exec.user_id, + message="Insufficient balance", + balance=0, + amount=extra_iterations * 10, + ) + + monkeypatch.setattr( + manager.ExecutionProcessor, "charge_extra_iterations", raise_ibe + ) + + stats_pair = ( + MagicMock( + node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0 + ), + threading.Lock(), + ) + result_stats = await proc.on_node_execution( + node_exec=_make_node_exec(dry_run=False), + node_exec_progress=MagicMock(), + nodes_input_masks=None, + graph_stats_pair=stats_pair, + ) + # error stays None — node ran to completion, only the post-hoc + # charge failed. 
Setting .error would (a) flip node_error_count++ + # creating an "errored COMPLETED node" inconsistency, and (b) leak + # balance amounts into persisted node_stats. + assert result_stats.error is None + # User notification fired. + assert len(calls["handle_insufficient_funds_notif"]) == 1 + assert calls["handle_insufficient_funds_notif"][0]["user_id"] == "u" + + +# ── Orchestrator _execute_single_tool_with_manager charging gates ── + + +async def _run_tool_exec_with_stats( + *, + dry_run: bool, + tool_stats_error, + charge_node_usage_mock=None, +): + """Invoke _execute_single_tool_with_manager against fully mocked deps + and return (charge_node_usage mock, raised exception or None). + + Used to prove the dry_run and error guards around charge_node_usage + behave as documented, and that InsufficientBalanceError propagates. + """ + from collections import defaultdict + from unittest.mock import AsyncMock, MagicMock, patch + + block = OrchestratorBlock() + + # Mocked async DB client used inside orchestrator. + mock_db_client = AsyncMock() + mock_target_node = MagicMock() + mock_target_node.block_id = "test-block-id" + mock_target_node.input_default = {} + mock_db_client.get_node.return_value = mock_target_node + mock_node_exec_result = MagicMock() + mock_node_exec_result.node_exec_id = "test-tool-exec-id" + mock_db_client.upsert_execution_input.return_value = ( + mock_node_exec_result, + {"query": "t"}, + ) + mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {"result": "ok"} + + # ExecutionProcessor mock: on_node_execution returns supplied error. 
+ mock_processor = AsyncMock() + mock_processor.running_node_execution = defaultdict(MagicMock) + mock_processor.execution_stats = MagicMock() + mock_processor.execution_stats_lock = threading.Lock() + mock_node_stats = MagicMock() + mock_node_stats.error = tool_stats_error + mock_processor.on_node_execution = AsyncMock(return_value=mock_node_stats) + mock_processor.charge_node_usage = charge_node_usage_mock or AsyncMock( + return_value=(10, 990) + ) + + # Build a tool_info shaped like _build_tool_info_from_args output. + tool_call = MagicMock() + tool_call.id = "call-1" + tool_call.name = "search_keywords" + tool_call.arguments = '{"query":"t"}' + tool_def = { + "type": "function", + "function": { + "name": "search_keywords", + "_sink_node_id": "test-sink-node-id", + "_field_mapping": {}, + "parameters": { + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + } + tool_info = OrchestratorBlock._build_tool_info_from_args( + tool_call_id="call-1", + tool_name="search_keywords", + tool_args={"query": "t"}, + tool_def=tool_def, + ) + + exec_params = ExecutionParams( + user_id="u", + graph_id="g", + node_id="n", + graph_version=1, + graph_exec_id="ge", + node_exec_id="ne", + execution_context=ExecutionContext( + human_in_the_loop_safe_mode=False, dry_run=dry_run + ), + ) + + with patch( + "backend.blocks.orchestrator.get_database_manager_async_client", + return_value=mock_db_client, + ): + try: + await block._execute_single_tool_with_manager( + tool_info, exec_params, mock_processor, responses_api=False + ) + raised = None + except Exception as e: + raised = e + + return mock_processor.charge_node_usage, raised + + +@pytest.mark.asyncio +async def test_tool_execution_skips_charging_on_dry_run(): + """dry_run=True → charge_node_usage is NOT called.""" + charge_mock, raised = await _run_tool_exec_with_stats( + dry_run=True, tool_stats_error=None + ) + assert raised is None + assert charge_mock.call_count == 0 + + +@pytest.mark.asyncio +async 
@pytest.mark.asyncio
async def test_on_node_execution_failed_ibe_sends_notification(
    monkeypatch,
    gated_processor,
):
    """A FAILED node carrying InsufficientBalanceError notifies the user.

    The inner ``_on_node_execution`` is replaced by a stub that stores an
    InsufficientBalanceError on ``stats.error`` and reports
    ``ExecutionStatus.FAILED``. ``on_node_execution`` is then expected to
    record exactly one insufficient-funds notification for user ``"u"`` and
    to leave ``charge_extra_iterations`` uncalled, because extra-iteration
    billing only applies to completed runs.
    """
    # Only the processor, the call recorder and the fake DB are used here.
    processor, recorded, _inner, db_stub, _stats_cls = gated_processor
    balance_error = InsufficientBalanceError(
        user_id="u", message="Insufficient balance", balance=0, amount=30
    )

    async def failing_inner(
        self,
        *,
        node,
        node_exec,
        node_exec_progress,
        stats,
        db_client,
        log_metadata,
        nodes_input_masks=None,
        nodes_to_skip=None,
    ):
        # Stand-in for _on_node_execution: report FAILED with the IBE attached.
        stats.error = balance_error
        timing = MagicMock(wall_time=0.1, cpu_time=0.1)
        return timing, ExecutionStatus.FAILED

    monkeypatch.setattr(
        manager.ExecutionProcessor, "_on_node_execution", failing_inner
    )
    db_stub.get_node = AsyncMock(return_value=_FakeNode(extra_charges=0))

    graph_stats = MagicMock(
        node_count=0, nodes_cputime=0, nodes_walltime=0, cost=0, node_error_count=0
    )
    graph_stats_lock = threading.Lock()

    await processor.on_node_execution(
        node_exec=_make_node_exec(dry_run=False),
        node_exec_progress=MagicMock(),
        nodes_input_masks=None,
        graph_stats_pair=(graph_stats, graph_stats_lock),
    )

    # Exactly one notification, addressed to the user who ran out of credits.
    notifications = recorded["handle_insufficient_funds_notif"]
    assert len(notifications) == 1
    assert notifications[0]["user_id"] == "u"
    # A FAILED run must never reach extra-iteration billing.
    assert recorded["charge_extra_iterations"] == []
class TestChargeUsageZeroExecutionCount:
    """Guard the ``execution_count=0`` path of ``_charge_usage``.

    With a count of zero only the block cost may be spent; the
    execution-count pricing tier must not be consulted at all.
    """

    def test_execution_count_zero_skips_execution_tier(self, monkeypatch):
        """Charging with a zero count spends the block cost and nothing else."""
        tier_calls = []
        charged: list[int] = []

        def record_tier_lookup(count):
            tier_calls.append(count)
            return (100, count)

        def fixed_block_cost(block, input_data, **_kw):
            return (10, {})

        class LedgerStub:
            def spend_credits(self, *, user_id, cost, metadata):
                charged.append(cost)
                return 500

        block_stub = MagicMock()
        # ``name`` is special in the Mock constructor, so assign it afterwards.
        block_stub.name = "FakeBlock"

        # Swap every collaborator _charge_usage looks up on the manager module.
        replacements = {
            "get_db_client": LedgerStub,
            "get_block": lambda block_id: block_stub,
            "block_usage_cost": fixed_block_cost,
            "execution_usage_cost": record_tier_lookup,
        }
        for attr, stand_in in replacements.items():
            monkeypatch.setattr(manager, attr, stand_in)

        # Bypass __init__ so no real executor resources are created.
        processor = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
        entry = MagicMock(
            user_id="u",
            graph_exec_id="ge",
            graph_id="g",
            node_exec_id="ne",
            node_id="n",
            block_id="b",
            inputs={},
        )

        spent_total, balance_left = processor._charge_usage(entry, 0)

        assert spent_total == 10  # block cost only
        assert balance_left == 500
        assert charged == [10]
        # The execution-count tier must not have been consulted.
        assert tier_calls == []
ep.execution_stats_lock = threading.Lock() ns = MagicMock(error=None) ep.on_node_execution = AsyncMock(return_value=ns) + # Mock charge_node_usage (called after successful tool execution). + # Must be AsyncMock because it is async and is awaited in + # _execute_single_tool_with_manager — a plain MagicMock would return a + # non-awaitable tuple and TypeError out, then be silently swallowed by + # the orchestrator's catch-all. + ep.charge_node_usage = AsyncMock(return_value=(0, 0)) with patch("backend.blocks.llm.llm_call", llm_mock), patch.object( block, "_create_tool_node_signatures", return_value=tool_sigs diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index bd718d168f..6e4b3e2aad 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -19,7 +19,7 @@ from sentry_sdk.api import flush as _sentry_flush from sentry_sdk.api import get_current_scope as _sentry_get_current_scope from backend.blocks import get_block -from backend.blocks._base import BlockSchema +from backend.blocks._base import Block, BlockSchema from backend.blocks.agent import AgentExecutorBlock from backend.blocks.io import AgentOutputBlock from backend.blocks.mcp.block import MCPToolBlock @@ -681,6 +681,10 @@ class ExecutionProcessor: execution_stats.walltime = timing_info.wall_time execution_stats.cputime = timing_info.cpu_time + await self._handle_post_execution_billing( + node, node_exec, execution_stats, status, log_metadata + ) + graph_stats, graph_stats_lock = graph_stats_pair with graph_stats_lock: graph_stats.node_count += 1 + execution_stats.extra_steps @@ -716,8 +720,121 @@ class ExecutionProcessor: db_client=db_client, ) + # If the node failed because a nested tool charge raised IBE, + # send the user notification so they understand why the run stopped. 
async def _handle_post_execution_billing(
    self,
    node: Node,
    node_exec: NodeExecutionEntry,
    execution_stats: NodeExecutionStats,
    status: ExecutionStatus,
    log_metadata: LogMetadata,
) -> None:
    """Charge extra iterations for blocks that opt into per-LLM-call billing.

    The block reports how many additional base-cost charges it wants through
    ``extra_credit_charges(execution_stats)``; that many are billed via
    ``charge_extra_iterations``. Nothing is charged unless the run is
    COMPLETED and is not a dry run.

    Charging failures never change the outcome of the run. The work is
    already done, so the failure is logged with ``billing_leak: True`` for
    alerting and ``execution_stats.error`` is left untouched. For
    ``InsufficientBalanceError`` the user is additionally notified.

    The low-balance check that follows a *successful* charge is handled
    separately, so that a notification failure is not misreported as a
    billing leak.

    Args:
        node: Node whose block decides the number of extra charges.
        node_exec: Execution entry identifying the user, graph and block.
        execution_stats: Stats of the finished run; ``extra_cost`` is
            increased by the amount charged.
        status: Final status of the node execution.
        log_metadata: Logger bound to this execution's metadata.
    """
    if status != ExecutionStatus.COMPLETED or node_exec.execution_context.dry_run:
        return
    extra_iterations = node.block.extra_credit_charges(execution_stats)
    if extra_iterations <= 0:
        return

    try:
        extra_cost, remaining_balance = await self.charge_extra_iterations(
            node_exec,
            extra_iterations,
        )
    except InsufficientBalanceError as e:
        log_metadata.error(
            "billing_leak: insufficient balance after "
            f"{node.block.name} completed {extra_iterations} "
            f"extra iterations",
            extra={
                "billing_leak": True,
                "user_id": node_exec.user_id,
                "graph_id": node_exec.graph_id,
                "block_id": node_exec.block_id,
                "extra_iterations": extra_iterations,
                "error": str(e),
            },
        )
        # Do NOT set execution_stats.error: the node ran to completion,
        # only the post-hoc charge failed.
        await self._try_send_insufficient_funds_notif(
            node_exec.user_id,
            node_exec.graph_id,
            e,
            log_metadata,
        )
        return
    except Exception as e:
        log_metadata.error(
            f"billing_leak: failed to charge extra iterations "
            f"for {node.block.name}",
            extra={
                "billing_leak": True,
                "user_id": node_exec.user_id,
                "graph_id": node_exec.graph_id,
                "block_id": node_exec.block_id,
                "extra_iterations": extra_iterations,
                "error_type": type(e).__name__,
                "error": str(e),
            },
            exc_info=True,
        )
        return

    if extra_cost <= 0:
        return
    execution_stats.extra_cost += extra_cost

    # The charge has already succeeded at this point. A failure in the
    # low-balance check is a notification problem, not a billing leak, so it
    # is reported as a plain warning instead of tripping billing alerts.
    try:
        await asyncio.to_thread(
            self._handle_low_balance,
            get_db_client(),
            node_exec.user_id,
            remaining_balance,
            extra_cost,
        )
    except Exception as e:
        log_metadata.warning(
            f"Failed to run low-balance check after charging extra iterations: {e}"
        )
+ + Shared by :meth:`_charge_usage` and :meth:`charge_extra_iterations` + so the (get_block, block_usage_cost) lookup lives in exactly one + place. Returns ``(block, cost, matching_filter)``. ``block`` is + ``None`` if the block id can't be resolved — callers should treat + that as "nothing to charge". + """ + block = get_block(node_exec.block_id) + if not block: + logger.error(f"Block {node_exec.block_id} not found.") + return None, 0, {} + cost, matching_filter = block_usage_cost( + block=block, input_data=node_exec.inputs + ) + return block, cost, matching_filter + def _charge_usage( self, node_exec: NodeExecutionEntry, @@ -952,14 +1090,10 @@ class ExecutionProcessor: total_cost = 0 remaining_balance = 0 db_client = get_db_client() - block = get_block(node_exec.block_id) + block, cost, matching_filter = self._resolve_block_cost(node_exec) if not block: - logger.error(f"Block {node_exec.block_id} not found.") return total_cost, 0 - cost, matching_filter = block_usage_cost( - block=block, input_data=node_exec.inputs - ) if cost > 0: remaining_balance = db_client.spend_credits( user_id=node_exec.user_id, @@ -977,7 +1111,13 @@ class ExecutionProcessor: ) total_cost += cost - cost, usage_count = execution_usage_cost(execution_count) + # execution_count=0 is used by charge_node_usage for nested tool calls + # which must not be pushed into higher execution-count tiers. + # execution_usage_cost(0) would trigger a charge because 0 % threshold == 0, + # so skip it entirely when execution_count is 0. + cost, usage_count = ( + execution_usage_cost(execution_count) if execution_count > 0 else (0, 0) + ) if cost > 0: remaining_balance = db_client.spend_credits( user_id=node_exec.user_id, @@ -996,6 +1136,116 @@ class ExecutionProcessor: return total_cost, remaining_balance + # Hard cap on the multiplier passed to charge_extra_iterations to + # protect against a corrupted llm_call_count draining a user's balance. 
async def charge_extra_iterations(
    self,
    node_exec: NodeExecutionEntry,
    extra_iterations: int,
) -> tuple[int, int]:
    """Charge a block for iterations beyond its initial run.

    Used by agent-mode blocks (e.g. OrchestratorBlock) that make multiple
    LLM calls within a single node execution. The first iteration is
    already charged by :meth:`_charge_usage`; this method charges
    *extra_iterations* additional copies of the block's base cost.

    The multiplier is capped at ``_MAX_EXTRA_ITERATIONS``. Hitting the cap
    means the reported count is implausibly high and the run is billed for
    fewer iterations than it reported, so a warning is logged rather than
    clamping silently.

    Args:
        node_exec: Execution entry of the node being charged.
        extra_iterations: Number of additional base-cost charges requested.
            Values of zero or less charge nothing.

    Returns:
        ``(total_extra_cost, remaining_balance)``; ``(0, 0)`` when nothing
        was requested.

    Raises:
        InsufficientBalanceError: If the user can't afford the charge.
    """
    if extra_iterations <= 0:
        return 0, 0
    # Cap to protect against a corrupted llm_call_count.
    capped = min(extra_iterations, self._MAX_EXTRA_ITERATIONS)
    if capped < extra_iterations:
        logger.warning(
            f"Extra iterations for node execution {node_exec.node_exec_id} "
            f"capped from {extra_iterations} to {capped}"
        )
    # The actual spend is blocking, so it runs in a worker thread.
    return await asyncio.to_thread(
        self._charge_extra_iterations_sync, node_exec, capped
    )