mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-30 03:00:41 -04:00
fix(orchestrator): address review feedback on per-iteration cost charging
Addresses critical issues from autogpt-pr-reviewer + coderabbit review: Billing safety: - Surface InsufficientBalanceError as ERROR (not warning) so monitoring picks up billing leaks; other charge failures still log a warning. - Cap extra_iterations at MAX_EXTRA_ITERATIONS=200 to prevent a corrupted llm_call_count from draining a user's balance. - Tools now charged AFTER successful execution, not before — failed tools no longer cost credits, matching the rest of the platform. - charge_node_usage uses execution_count=0 so nested tool calls don't inflate the per-execution counter / push users into higher cost tiers. - charge_extra_iterations now returns (cost, remaining_balance) and the caller invokes _handle_low_balance to send low-balance notifications. Error handling consistency: - _execute_single_tool_with_manager re-raises InsufficientBalanceError instead of swallowing it into a tool-error response. This prevents leaking the user's exact balance to the LLM context and lets the outer error handling stop the run cleanly, mirroring the main queue. Test fixes: - test_orchestrator_per_iteration_cost.py: rewritten with pytest monkeypatch fixtures (no more manual save/restore), proper FakeBlock with .name attribute set correctly, plus new tests for the cap, block-not-found, InsufficientBalanceError propagation, and charge_node_usage delegation. - test_orchestrator.py / test_orchestrator_responses_api.py / test_orchestrator_dynamic_fields.py: mock charge_node_usage on the execution processor stub so existing agent-mode tests still pass with the new charging call. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -36,6 +36,7 @@ from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import NodeExecutionStats, SchemaField
|
||||
from backend.util import json
|
||||
from backend.util.clients import get_database_manager_async_client
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
from backend.util.prompt import MAIN_OBJECTIVE_PREFIX
|
||||
from backend.util.security import SENSITIVE_FIELD_NAMES
|
||||
from backend.util.tool_call_loop import (
|
||||
@@ -1107,18 +1108,6 @@ class OrchestratorBlock(Block):
|
||||
execution_processor.execution_stats_lock,
|
||||
)
|
||||
|
||||
# Charge user credits for the tool execution. Tools spawned by the
|
||||
# orchestrator bypass the main execution queue (where _charge_usage
|
||||
# is called), so we must charge here to avoid free tool execution.
|
||||
# Skipped for dry runs and when block has no cost configured.
|
||||
if not execution_params.execution_context.dry_run:
|
||||
tool_cost, _ = await asyncio.to_thread(
|
||||
execution_processor.charge_node_usage,
|
||||
node_exec_entry,
|
||||
)
|
||||
if tool_cost > 0:
|
||||
self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))
|
||||
|
||||
# Create a completed future for the task tracking system
|
||||
node_exec_future = Future()
|
||||
node_exec_progress.add_task(
|
||||
@@ -1127,14 +1116,31 @@ class OrchestratorBlock(Block):
|
||||
)
|
||||
|
||||
# Execute the node directly since we're in the Orchestrator context
|
||||
node_exec_future.set_result(
|
||||
await execution_processor.on_node_execution(
|
||||
node_exec=node_exec_entry,
|
||||
node_exec_progress=node_exec_progress,
|
||||
nodes_input_masks=None,
|
||||
graph_stats_pair=graph_stats_pair,
|
||||
)
|
||||
tool_node_stats = await execution_processor.on_node_execution(
|
||||
node_exec=node_exec_entry,
|
||||
node_exec_progress=node_exec_progress,
|
||||
nodes_input_masks=None,
|
||||
graph_stats_pair=graph_stats_pair,
|
||||
)
|
||||
node_exec_future.set_result(tool_node_stats)
|
||||
|
||||
# Charge user credits AFTER successful tool execution. Tools
|
||||
# spawned by the orchestrator bypass the main execution queue
|
||||
# (where _charge_usage is called), so we must charge here to
|
||||
# avoid free tool execution. Charging post-completion (vs.
|
||||
# pre-execution) avoids billing users for failed tool calls.
|
||||
# Skipped for dry runs.
|
||||
if (
|
||||
not execution_params.execution_context.dry_run
|
||||
and tool_node_stats is not None
|
||||
and not isinstance(tool_node_stats.error, Exception)
|
||||
):
|
||||
tool_cost, _ = await asyncio.to_thread(
|
||||
execution_processor.charge_node_usage,
|
||||
node_exec_entry,
|
||||
)
|
||||
if tool_cost > 0:
|
||||
self.merge_stats(NodeExecutionStats(extra_cost=tool_cost))
|
||||
|
||||
# Get outputs from database after execution completes using database manager client
|
||||
node_outputs = await db_client.get_execution_outputs_by_node_exec_id(
|
||||
@@ -1151,6 +1157,12 @@ class OrchestratorBlock(Block):
|
||||
tool_call.id, tool_response_content, responses_api=responses_api
|
||||
)
|
||||
|
||||
except InsufficientBalanceError:
|
||||
# Don't downgrade billing failures into tool errors — let the
|
||||
# orchestrator's outer error handling stop the run cleanly,
|
||||
# mirroring the behaviour of the main execution queue. Also
|
||||
# prevents leaking exact balance amounts to the LLM context.
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning("Tool execution with manager failed: %s", e)
|
||||
# Return error response
|
||||
|
||||
@@ -922,6 +922,10 @@ async def test_orchestrator_agent_mode():
|
||||
mock_execution_processor.on_node_execution = AsyncMock(
|
||||
return_value=mock_node_stats
|
||||
)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
# Returns (cost, remaining_balance). Synchronous because it's called
|
||||
# via asyncio.to_thread.
|
||||
mock_execution_processor.charge_node_usage = MagicMock(return_value=(0, 0))
|
||||
|
||||
# Mock the get_execution_outputs_by_node_exec_id method
|
||||
mock_db_client.get_execution_outputs_by_node_exec_id.return_value = {
|
||||
|
||||
@@ -638,6 +638,10 @@ async def test_validation_errors_dont_pollute_conversation():
|
||||
mock_execution_processor.on_node_execution.return_value = (
|
||||
mock_node_stats
|
||||
)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
mock_execution_processor.charge_node_usage = MagicMock(
|
||||
return_value=(0, 0)
|
||||
)
|
||||
|
||||
async for output_name, output_value in block.run(
|
||||
input_data,
|
||||
|
||||
@@ -8,9 +8,14 @@ the block completes.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.blocks.orchestrator import OrchestratorBlock
|
||||
|
||||
|
||||
# ── Class flag ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestChargePerLlmCallFlag:
|
||||
"""OrchestratorBlock opts into per-LLM-call billing."""
|
||||
|
||||
@@ -23,16 +28,117 @@ class TestChargePerLlmCallFlag:
|
||||
assert Block.charge_per_llm_call is False
|
||||
|
||||
|
||||
class TestChargeExtraIterations:
|
||||
"""The executor charges ``cost * (llm_call_count - 1)`` extra credits."""
|
||||
# ── charge_extra_iterations math ───────────────────────────────────
|
||||
|
||||
def _make_processor_with_block_cost(self, base_cost: int):
|
||||
"""Build a minimal ExecutionProcessor stub with a stubbed block lookup."""
|
||||
|
||||
@pytest.fixture()
|
||||
def fake_node_exec():
|
||||
node_exec = MagicMock()
|
||||
node_exec.user_id = "u"
|
||||
node_exec.graph_exec_id = "g"
|
||||
node_exec.graph_id = "g"
|
||||
node_exec.node_exec_id = "ne"
|
||||
node_exec.node_id = "n"
|
||||
node_exec.block_id = "b"
|
||||
node_exec.inputs = {}
|
||||
return node_exec
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def patched_processor(monkeypatch):
|
||||
"""ExecutionProcessor with stubbed db client / block lookup helpers.
|
||||
|
||||
Returns the processor and a list of credit amounts spent so tests can
|
||||
assert on what was charged.
|
||||
"""
|
||||
from backend.executor import manager
|
||||
|
||||
spent: list[int] = []
|
||||
|
||||
class FakeDb:
|
||||
def spend_credits(self, *, user_id, cost, metadata):
|
||||
spent.append(cost)
|
||||
return 1000 # remaining balance
|
||||
|
||||
fake_block = MagicMock()
|
||||
fake_block.name = "FakeBlock"
|
||||
|
||||
monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
|
||||
monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
|
||||
monkeypatch.setattr(
|
||||
manager,
|
||||
"block_usage_cost",
|
||||
lambda block, input_data, **_kw: (10, {"model": "claude-sonnet-4-6"}),
|
||||
)
|
||||
|
||||
proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
|
||||
return proc, spent
|
||||
|
||||
|
||||
class TestChargeExtraIterations:
|
||||
def test_zero_extra_iterations_charges_nothing(
|
||||
self, patched_processor, fake_node_exec
|
||||
):
|
||||
proc, spent = patched_processor
|
||||
cost, balance = proc.charge_extra_iterations(fake_node_exec, extra_iterations=0)
|
||||
assert cost == 0
|
||||
assert balance == 0
|
||||
assert spent == []
|
||||
|
||||
def test_extra_iterations_multiplies_base_cost(
|
||||
self, patched_processor, fake_node_exec
|
||||
):
|
||||
proc, spent = patched_processor
|
||||
cost, balance = proc.charge_extra_iterations(fake_node_exec, extra_iterations=4)
|
||||
assert cost == 40 # 4 × 10
|
||||
assert balance == 1000
|
||||
assert spent == [40]
|
||||
|
||||
def test_negative_extra_iterations_charges_nothing(
|
||||
self, patched_processor, fake_node_exec
|
||||
):
|
||||
proc, spent = patched_processor
|
||||
cost, balance = proc.charge_extra_iterations(
|
||||
fake_node_exec, extra_iterations=-1
|
||||
)
|
||||
assert cost == 0
|
||||
assert balance == 0
|
||||
assert spent == []
|
||||
|
||||
def test_capped_at_max(self, monkeypatch, fake_node_exec):
|
||||
"""Runaway llm_call_count is capped at _MAX_EXTRA_ITERATIONS."""
|
||||
from backend.executor import manager
|
||||
|
||||
proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
|
||||
spent: list[int] = []
|
||||
|
||||
class FakeDb:
|
||||
def spend_credits(self, *, user_id, cost, metadata):
|
||||
spent.append(cost)
|
||||
return 1000
|
||||
|
||||
fake_block = MagicMock()
|
||||
fake_block.name = "FakeBlock"
|
||||
|
||||
monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
|
||||
monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
|
||||
monkeypatch.setattr(
|
||||
manager,
|
||||
"block_usage_cost",
|
||||
lambda block, input_data, **_kw: (10, {}),
|
||||
)
|
||||
|
||||
proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
|
||||
cap = manager.ExecutionProcessor._MAX_EXTRA_ITERATIONS
|
||||
cost, _ = proc.charge_extra_iterations(
|
||||
fake_node_exec, extra_iterations=cap * 100
|
||||
)
|
||||
# Charged at most cap × 10
|
||||
assert cost == cap * 10
|
||||
assert spent == [cap * 10]
|
||||
|
||||
def test_zero_base_cost_skips_charge(self, monkeypatch, fake_node_exec):
|
||||
from backend.executor import manager
|
||||
|
||||
# Stub the spend_credits client and block_usage_cost helper.
|
||||
spent: list[int] = []
|
||||
|
||||
class FakeDb:
|
||||
@@ -40,95 +146,94 @@ class TestChargeExtraIterations:
|
||||
spent.append(cost)
|
||||
return 0
|
||||
|
||||
# Patch get_db_client and get_block + block_usage_cost on the manager
|
||||
# module so charge_extra_iterations sees deterministic values.
|
||||
original_get_db = manager.get_db_client
|
||||
original_get_block = manager.get_block
|
||||
original_block_usage_cost = manager.block_usage_cost
|
||||
fake_block = MagicMock()
|
||||
fake_block.name = "FakeBlock"
|
||||
|
||||
def restore():
|
||||
manager.get_db_client = original_get_db
|
||||
manager.get_block = original_get_block
|
||||
manager.block_usage_cost = original_block_usage_cost
|
||||
|
||||
manager.get_db_client = lambda: FakeDb()
|
||||
manager.get_block = lambda block_id: MagicMock(name="block")
|
||||
manager.block_usage_cost = lambda block, input_data: (
|
||||
base_cost,
|
||||
{"model": "claude-sonnet-4-6"},
|
||||
monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
|
||||
monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
|
||||
monkeypatch.setattr(
|
||||
manager, "block_usage_cost", lambda block, input_data, **_kw: (0, {})
|
||||
)
|
||||
|
||||
return proc, spent, restore
|
||||
proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
|
||||
cost, balance = proc.charge_extra_iterations(fake_node_exec, extra_iterations=4)
|
||||
assert cost == 0
|
||||
assert balance == 0
|
||||
assert spent == []
|
||||
|
||||
def test_zero_extra_iterations_charges_nothing(self):
|
||||
proc, spent, restore = self._make_processor_with_block_cost(base_cost=10)
|
||||
try:
|
||||
node_exec = MagicMock()
|
||||
node_exec.user_id = "u"
|
||||
node_exec.graph_exec_id = "g"
|
||||
node_exec.graph_id = "g"
|
||||
node_exec.node_exec_id = "ne"
|
||||
node_exec.node_id = "n"
|
||||
node_exec.block_id = "b"
|
||||
node_exec.inputs = {}
|
||||
def test_block_not_found_skips_charge(self, monkeypatch, fake_node_exec):
|
||||
from backend.executor import manager
|
||||
|
||||
charged = proc.charge_extra_iterations(node_exec, extra_iterations=0)
|
||||
assert charged == 0
|
||||
assert spent == []
|
||||
finally:
|
||||
restore()
|
||||
spent: list[int] = []
|
||||
|
||||
def test_extra_iterations_multiplies_base_cost(self):
|
||||
proc, spent, restore = self._make_processor_with_block_cost(base_cost=10)
|
||||
try:
|
||||
node_exec = MagicMock()
|
||||
node_exec.user_id = "u"
|
||||
node_exec.graph_exec_id = "g"
|
||||
node_exec.graph_id = "g"
|
||||
node_exec.node_exec_id = "ne"
|
||||
node_exec.node_id = "n"
|
||||
node_exec.block_id = "b"
|
||||
node_exec.inputs = {}
|
||||
class FakeDb:
|
||||
def spend_credits(self, *, user_id, cost, metadata):
|
||||
spent.append(cost)
|
||||
return 0
|
||||
|
||||
charged = proc.charge_extra_iterations(node_exec, extra_iterations=4)
|
||||
# 4 extra iterations × 10 base_cost = 40
|
||||
assert charged == 40
|
||||
assert spent == [40]
|
||||
finally:
|
||||
restore()
|
||||
monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
|
||||
monkeypatch.setattr(manager, "get_block", lambda block_id: None)
|
||||
monkeypatch.setattr(
|
||||
manager, "block_usage_cost", lambda block, input_data, **_kw: (10, {})
|
||||
)
|
||||
|
||||
def test_zero_base_cost_skips_charge(self):
|
||||
proc, spent, restore = self._make_processor_with_block_cost(base_cost=0)
|
||||
try:
|
||||
node_exec = MagicMock()
|
||||
node_exec.user_id = "u"
|
||||
node_exec.graph_exec_id = "g"
|
||||
node_exec.graph_id = "g"
|
||||
node_exec.node_exec_id = "ne"
|
||||
node_exec.node_id = "n"
|
||||
node_exec.block_id = "b"
|
||||
node_exec.inputs = {}
|
||||
proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
|
||||
cost, balance = proc.charge_extra_iterations(fake_node_exec, extra_iterations=3)
|
||||
assert cost == 0
|
||||
assert balance == 0
|
||||
assert spent == []
|
||||
|
||||
charged = proc.charge_extra_iterations(node_exec, extra_iterations=4)
|
||||
assert charged == 0
|
||||
assert spent == []
|
||||
finally:
|
||||
restore()
|
||||
def test_propagates_insufficient_balance_error(self, monkeypatch, fake_node_exec):
|
||||
"""Out-of-credits errors must propagate, not be silently swallowed."""
|
||||
from backend.executor import manager
|
||||
from backend.util.exceptions import InsufficientBalanceError
|
||||
|
||||
def test_negative_extra_iterations_charges_nothing(self):
|
||||
proc, spent, restore = self._make_processor_with_block_cost(base_cost=10)
|
||||
try:
|
||||
node_exec = MagicMock()
|
||||
node_exec.user_id = "u"
|
||||
node_exec.graph_exec_id = "g"
|
||||
node_exec.graph_id = "g"
|
||||
node_exec.node_exec_id = "ne"
|
||||
node_exec.node_id = "n"
|
||||
node_exec.block_id = "b"
|
||||
node_exec.inputs = {}
|
||||
class FakeDb:
|
||||
def spend_credits(self, *, user_id, cost, metadata):
|
||||
raise InsufficientBalanceError(
|
||||
user_id=user_id,
|
||||
message="Insufficient balance",
|
||||
balance=0,
|
||||
amount=cost,
|
||||
)
|
||||
|
||||
charged = proc.charge_extra_iterations(node_exec, extra_iterations=-1)
|
||||
assert charged == 0
|
||||
assert spent == []
|
||||
finally:
|
||||
restore()
|
||||
fake_block = MagicMock()
|
||||
fake_block.name = "FakeBlock"
|
||||
|
||||
monkeypatch.setattr(manager, "get_db_client", lambda: FakeDb())
|
||||
monkeypatch.setattr(manager, "get_block", lambda block_id: fake_block)
|
||||
monkeypatch.setattr(
|
||||
manager, "block_usage_cost", lambda block, input_data, **_kw: (10, {})
|
||||
)
|
||||
|
||||
proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
|
||||
with pytest.raises(InsufficientBalanceError):
|
||||
proc.charge_extra_iterations(fake_node_exec, extra_iterations=4)
|
||||
|
||||
|
||||
# ── charge_node_usage ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestChargeNodeUsage:
|
||||
"""charge_node_usage delegates to _charge_usage with execution_count=0."""
|
||||
|
||||
def test_delegates_with_zero_execution_count(self, monkeypatch, fake_node_exec):
|
||||
"""Nested tool charges should NOT inflate the per-execution counter."""
|
||||
from backend.executor import manager
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
def fake_charge_usage(self, node_exec, execution_count):
|
||||
captured["execution_count"] = execution_count
|
||||
captured["node_exec"] = node_exec
|
||||
return (5, 100)
|
||||
|
||||
monkeypatch.setattr(
|
||||
manager.ExecutionProcessor, "_charge_usage", fake_charge_usage
|
||||
)
|
||||
|
||||
proc = manager.ExecutionProcessor.__new__(manager.ExecutionProcessor)
|
||||
cost, balance = proc.charge_node_usage(fake_node_exec)
|
||||
assert cost == 5
|
||||
assert balance == 100
|
||||
assert captured["execution_count"] == 0
|
||||
|
||||
@@ -956,6 +956,8 @@ async def test_agent_mode_conversation_valid_for_responses_api():
|
||||
ep.execution_stats_lock = threading.Lock()
|
||||
ns = MagicMock(error=None)
|
||||
ep.on_node_execution = AsyncMock(return_value=ns)
|
||||
# Mock charge_node_usage (called after successful tool execution).
|
||||
ep.charge_node_usage = MagicMock(return_value=(0, 0))
|
||||
|
||||
with patch("backend.blocks.llm.llm_call", llm_mock), patch.object(
|
||||
block, "_create_tool_node_signatures", return_value=tool_sigs
|
||||
|
||||
@@ -683,6 +683,11 @@ class ExecutionProcessor:
|
||||
# billing (e.g. OrchestratorBlock in agent mode). The first call
|
||||
# is already covered by _charge_usage(); each additional LLM call
|
||||
# costs another base_cost. Skipped for dry runs and failed runs.
|
||||
#
|
||||
# InsufficientBalanceError is logged at ERROR level (this is a
|
||||
# billing leak — the work is already done, but the user can't pay)
|
||||
# and re-surfaced via execution_stats.error so monitoring can pick
|
||||
# it up. Other exceptions are warnings.
|
||||
if (
|
||||
status == ExecutionStatus.COMPLETED
|
||||
and node.block.charge_per_llm_call
|
||||
@@ -691,16 +696,27 @@ class ExecutionProcessor:
|
||||
):
|
||||
extra_iterations = execution_stats.llm_call_count - 1
|
||||
try:
|
||||
extra_cost = await asyncio.to_thread(
|
||||
extra_cost, remaining_balance = await asyncio.to_thread(
|
||||
self.charge_extra_iterations,
|
||||
node_exec,
|
||||
extra_iterations,
|
||||
)
|
||||
if extra_cost > 0:
|
||||
execution_stats.extra_cost += extra_cost
|
||||
self._handle_low_balance(
|
||||
db_client=get_db_client(),
|
||||
user_id=node_exec.user_id,
|
||||
current_balance=remaining_balance,
|
||||
transaction_cost=extra_cost,
|
||||
)
|
||||
except InsufficientBalanceError as e:
|
||||
log_metadata.error(
|
||||
f"Billing leak: insufficient balance after {node.block.name} "
|
||||
f"completed {extra_iterations} extra iterations: {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
log_metadata.warning(
|
||||
f"Failed to charge extra iterations for " f"{node.block.name}: {e}"
|
||||
f"Failed to charge extra iterations for {node.block.name}: {e}"
|
||||
)
|
||||
|
||||
graph_stats, graph_stats_lock = graph_stats_pair
|
||||
@@ -1018,6 +1034,12 @@ class ExecutionProcessor:
|
||||
|
||||
return total_cost, remaining_balance
|
||||
|
||||
# Hard cap on the multiplier passed to charge_extra_iterations to
|
||||
# protect against a corrupted llm_call_count draining a user's balance.
|
||||
# Real agent-mode runs are bounded by agent_mode_max_iterations (~50);
|
||||
# 200 leaves headroom while preventing runaway charges.
|
||||
_MAX_EXTRA_ITERATIONS = 200
|
||||
|
||||
def charge_node_usage(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
@@ -1027,17 +1049,22 @@ class ExecutionProcessor:
|
||||
Public wrapper around :meth:`_charge_usage` for blocks (e.g. the
|
||||
OrchestratorBlock) that spawn nested node executions outside the
|
||||
main queue and therefore need to charge them explicitly.
|
||||
|
||||
Note: this **does not** increment the global execution counter
|
||||
(``increment_execution_count``). Nested tool executions are
|
||||
sub-steps of a single block run from the user's perspective and
|
||||
should not push them into higher per-execution cost tiers.
|
||||
"""
|
||||
return self._charge_usage(
|
||||
node_exec=node_exec,
|
||||
execution_count=increment_execution_count(node_exec.user_id),
|
||||
execution_count=0,
|
||||
)
|
||||
|
||||
def charge_extra_iterations(
|
||||
self,
|
||||
node_exec: NodeExecutionEntry,
|
||||
extra_iterations: int,
|
||||
) -> int:
|
||||
) -> tuple[int, int]:
|
||||
"""Charge a block extra iterations beyond the initial run.
|
||||
|
||||
Used by agent-mode blocks (e.g. OrchestratorBlock) that make
|
||||
@@ -1045,20 +1072,25 @@ class ExecutionProcessor:
|
||||
iteration is already charged by :meth:`_charge_usage`; this
|
||||
method charges *extra_iterations* additional copies of the
|
||||
block's base cost.
|
||||
|
||||
Returns ``(total_extra_cost, remaining_balance)``. May raise
|
||||
``InsufficientBalanceError`` if the user can't afford the charge.
|
||||
"""
|
||||
if extra_iterations <= 0:
|
||||
return 0
|
||||
return 0, 0
|
||||
# Cap to protect against a corrupted llm_call_count.
|
||||
capped_iterations = min(extra_iterations, self._MAX_EXTRA_ITERATIONS)
|
||||
db_client = get_db_client()
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
return 0
|
||||
return 0, 0
|
||||
cost, matching_filter = block_usage_cost(
|
||||
block=block, input_data=node_exec.inputs
|
||||
)
|
||||
if cost <= 0:
|
||||
return 0
|
||||
total_extra_cost = cost * extra_iterations
|
||||
db_client.spend_credits(
|
||||
return 0, 0
|
||||
total_extra_cost = cost * capped_iterations
|
||||
remaining_balance = db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
cost=total_extra_cost,
|
||||
metadata=UsageTransactionMetadata(
|
||||
@@ -1070,15 +1102,15 @@ class ExecutionProcessor:
|
||||
block=block.name,
|
||||
input={
|
||||
**matching_filter,
|
||||
"extra_iterations": extra_iterations,
|
||||
"extra_iterations": capped_iterations,
|
||||
},
|
||||
reason=(
|
||||
f"Extra agent-mode iterations for {block.name} "
|
||||
f"({extra_iterations} additional LLM calls)"
|
||||
f"({capped_iterations} additional LLM calls)"
|
||||
),
|
||||
),
|
||||
)
|
||||
return total_extra_cost
|
||||
return total_extra_cost, remaining_balance
|
||||
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
|
||||
Reference in New Issue
Block a user