feat(backend): Add capability to charge based on block execution count (#9661)

Blocks that have no entry in the block cost config are effectively free.
The lack of cost control makes it hard to enforce usage quotas. This
change provides a way to charge any execution based on the number of
blocks executed, in real time.

### Changes 🏗️

* Add execution charge logic based on the number of blocks executed,
controlled by these two configurations:
* `execution_cost_count_threshold`: The execution is charged once for
every multiple of this number of block executions.
* `execution_cost_per_threshold`: The amount charged each time the
threshold is reached.
* Move the charging logic to the graph execution level (as opposed to the
node level) so that charging is done serially and an insufficient-funds
error is guaranteed to stop the graph execution.
* Moved cost calculation logic into backend/executor/utils.py

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
- [x] Execute a graph with a configured threshold & cost and verify that
the balance is deducted accordingly.
  - [x] Existing cost calculation still works without any issues.
  - [x] A low balance stops the whole graph execution.
This commit is contained in:
Zamil Majdy
2025-03-24 14:26:33 +07:00
committed by GitHub
parent 5b118fc939
commit 26984a7338
7 changed files with 241 additions and 136 deletions

View File

@@ -15,14 +15,11 @@ from prisma.enums import (
from prisma.errors import UniqueViolationError
from prisma.models import CreditRefundRequest, CreditTransaction, User
from prisma.types import CreditTransactionCreateInput, CreditTransactionWhereInput
from pydantic import BaseModel
from tenacity import retry, stop_after_attempt, wait_exponential
from backend.data import db
from backend.data.block import Block, BlockInput, get_block
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCost, BlockCostType
from backend.data.execution import NodeExecutionEntry
from backend.data.cost import BlockCost
from backend.data.model import (
AutoTopUpConfig,
RefundRequest,
@@ -31,6 +28,7 @@ from backend.data.model import (
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.executor.utils import UsageTransactionMetadata
from backend.notifications import NotificationManager
from backend.util.exceptions import InsufficientBalanceError
from backend.util.service import get_service_client
@@ -91,20 +89,20 @@ class UserCreditBase(ABC):
@abstractmethod
async def spend_credits(
self,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
) -> int:
"""
Spend the credits for the user based on the block usage.
Spend the credits for the user based on the cost.
Args:
entry (NodeExecutionEntry): The node execution identifiers & data.
data_size (float): The size of the data being processed.
run_time (float): The time taken to run the block.
user_id (str): The user ID.
cost (int): The cost to spend.
metadata (UsageTransactionMetadata): The metadata of the transaction.
Returns:
int: amount of credit spent
int: The remaining balance.
"""
pass
@@ -348,16 +346,6 @@ class UserCreditBase(ABC):
return user_balance + amount, tx.transactionKey
class UsageTransactionMetadata(BaseModel):
graph_exec_id: str | None = None
graph_id: str | None = None
node_id: str | None = None
node_exec_id: str | None = None
block_id: str | None = None
block: str | None = None
input: BlockInput | None = None
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManager:
@@ -378,89 +366,21 @@ class UserCredit(UserCreditBase):
)
)
def _block_usage_cost(
self,
block: Block,
input_data: BlockInput,
data_size: float,
run_time: float,
) -> tuple[int, BlockInput]:
block_costs = BLOCK_COSTS.get(type(block))
if not block_costs:
return 0, {}
for block_cost in block_costs:
if not self._is_cost_filter_match(block_cost.cost_filter, input_data):
continue
if block_cost.cost_type == BlockCostType.RUN:
return block_cost.cost_amount, block_cost.cost_filter
if block_cost.cost_type == BlockCostType.SECOND:
return (
int(run_time * block_cost.cost_amount),
block_cost.cost_filter,
)
if block_cost.cost_type == BlockCostType.BYTE:
return (
int(data_size * block_cost.cost_amount),
block_cost.cost_filter,
)
return 0, {}
def _is_cost_filter_match(
self, cost_filter: BlockInput, input_data: BlockInput
) -> bool:
"""
Filter rules:
- If cost_filter is an object, then check if cost_filter is the subset of input_data
- Otherwise, check if cost_filter is equal to input_data.
- Undefined, null, and empty string are considered as equal.
"""
if not isinstance(cost_filter, dict) or not isinstance(input_data, dict):
return cost_filter == input_data
return all(
(not input_data.get(k) and not v)
or (input_data.get(k) and self._is_cost_filter_match(v, input_data[k]))
for k, v in cost_filter.items()
)
async def spend_credits(
self,
entry: NodeExecutionEntry,
data_size: float,
run_time: float,
user_id: str,
cost: int,
metadata: UsageTransactionMetadata,
) -> int:
block = get_block(entry.block_id)
if not block:
raise ValueError(f"Block not found: {entry.block_id}")
cost, matching_filter = self._block_usage_cost(
block=block, input_data=entry.data, data_size=data_size, run_time=run_time
)
if cost == 0:
return 0
balance, _ = await self._add_transaction(
user_id=entry.user_id,
user_id=user_id,
amount=-cost,
transaction_type=CreditTransactionType.USAGE,
metadata=Json(
UsageTransactionMetadata(
graph_exec_id=entry.graph_exec_id,
graph_id=entry.graph_id,
node_id=entry.node_id,
node_exec_id=entry.node_exec_id,
block_id=entry.block_id,
block=block.name,
input=matching_filter,
).model_dump()
),
metadata=Json(metadata.model_dump()),
)
user_id = entry.user_id
# Auto top-up if balance is below threshold.
auto_top_up = await get_auto_top_up(user_id)
@@ -470,7 +390,7 @@ class UserCredit(UserCreditBase):
user_id=user_id,
amount=auto_top_up.amount,
# Avoid multiple auto top-ups within the same graph execution.
key=f"AUTO-TOP-UP-{user_id}-{entry.graph_exec_id}",
key=f"AUTO-TOP-UP-{user_id}-{metadata.graph_exec_id}",
ceiling_balance=auto_top_up.threshold,
)
except Exception as e:
@@ -479,7 +399,7 @@ class UserCredit(UserCreditBase):
f"Auto top-up failed for user {user_id}, balance: {balance}, amount: {auto_top_up.amount}, error: {e}"
)
return cost
return balance
async def top_up_credits(self, user_id: str, amount: int):
await self._top_up_credits(user_id, amount)

View File

@@ -415,7 +415,6 @@ class NodeExecutionStats(BaseModel):
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
cost: float = 0
input_size: int = 0
output_size: int = 0
llm_call_count: int = 0

View File

@@ -1,7 +1,6 @@
from backend.data.credit import get_user_credit_model
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
ExecutionResult,
NodeExecutionEntry,
RedisExecutionEventBus,
create_graph_execution,
get_execution_results,
@@ -45,8 +44,10 @@ config = Config()
_user_credit_model = get_user_credit_model()
async def _spend_credits(entry: NodeExecutionEntry) -> int:
return await _user_credit_model.spend_credits(entry, 0, 0)
async def _spend_credits(
    user_id: str, cost: int, metadata: UsageTransactionMetadata
) -> int:
    """
    Forward a spend request to the module-level user credit model.

    Args:
        user_id: The user to charge.
        cost: The amount to spend.
        metadata: Metadata stored alongside the usage transaction.

    Returns:
        Whatever the credit model's `spend_credits` returns.
    """
    outcome = await _user_credit_model.spend_credits(user_id, cost, metadata)
    return outcome
class DatabaseManager(AppService):

View File

@@ -48,6 +48,11 @@ from backend.data.execution import (
parse_execution_output,
)
from backend.data.graph import GraphModel, Link, Node
from backend.executor.utils import (
UsageTransactionMetadata,
block_usage_cost,
execution_usage_cost,
)
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util import json
from backend.util.decorator import error_logged, time_measured
@@ -206,11 +211,7 @@ def execute_node(
extra_exec_kwargs[field_name] = credentials
output_size = 0
cost = 0
try:
# Charge the user for the execution before running the block.
cost = db_client.spend_credits(data)
outputs: dict[str, Any] = {}
for output_name, output_data in node_block.execute(
input_data, **extra_exec_kwargs
@@ -266,7 +267,6 @@ def execute_node(
)
execution_stats.input_size = input_size
execution_stats.output_size = output_size
execution_stats.cost = cost
def _enqueue_next_nodes(
@@ -645,6 +645,53 @@ class Executor:
cls._handle_agent_run_notif(graph_exec, exec_stats)
@classmethod
def _charge_usage(
    cls,
    node_exec: NodeExecutionEntry,
    execution_count: int,
    execution_stats: GraphExecutionStats,
) -> int:
    """
    Charge the user for one node execution before it is dispatched.

    Two independent charges may be applied:
      1. The block cost, as computed by `block_usage_cost` for the node's
         block and input data.
      2. The execution-count cost, as computed by `execution_usage_cost`
         from the number of executions not yet charged for.

    Every amount actually charged is added to `execution_stats.cost`.

    Args:
        node_exec: The node execution about to run.
        execution_count: Number of executions not yet charged for,
            including this one.
        execution_stats: Graph-level stats; its `cost` is updated in place.

    Returns:
        The execution count left over after charging, to be carried into the
        next call. If the block cannot be found, nothing is charged and
        `execution_count` is returned unchanged.

    Raises:
        Anything raised by `db_client.spend_credits` is propagated to the
        caller unchanged.
    """
    block = get_block(node_exec.block_id)
    if not block:
        logger.error(f"Block {node_exec.block_id} not found.")
        return execution_count

    cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
    if cost > 0:
        cls.db_client.spend_credits(
            user_id=node_exec.user_id,
            cost=cost,
            metadata=UsageTransactionMetadata(
                graph_exec_id=node_exec.graph_exec_id,
                graph_id=node_exec.graph_id,
                node_exec_id=node_exec.node_exec_id,
                node_id=node_exec.node_id,
                block_id=node_exec.block_id,
                block=block.name,
                input=matching_filter,
            ),
        )
        execution_stats.cost += cost

    cost, remaining_count = execution_usage_cost(execution_count)
    if cost > 0:
        # Record how many executions this charge covers. The remainder
        # returned by execution_usage_cost is what is left AFTER charging,
        # so storing it here would lose the information being billed for.
        charged_count = execution_count - remaining_count
        cls.db_client.spend_credits(
            user_id=node_exec.user_id,
            cost=cost,
            metadata=UsageTransactionMetadata(
                graph_exec_id=node_exec.graph_exec_id,
                graph_id=node_exec.graph_id,
                input={
                    "execution_count": charged_count,
                    "charge": "Execution Cost",
                },
            ),
        )
        execution_stats.cost += cost

    return remaining_count
@classmethod
@time_measured
def _on_graph_execution(
@@ -681,8 +728,8 @@ class Executor:
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
exec_cost_counter = 0
running_executions: dict[str, AsyncResult] = {}
low_balance_error: Optional[InsufficientBalanceError] = None
def make_exec_callback(exec_data: NodeExecutionEntry):
@@ -692,17 +739,13 @@ class Executor:
if not isinstance(result, NodeExecutionStats):
return
nonlocal exec_stats, low_balance_error
nonlocal exec_stats
exec_stats.node_count += 1
exec_stats.nodes_cputime += result.cputime
exec_stats.nodes_walltime += result.walltime
exec_stats.cost += result.cost
if (err := result.error) and isinstance(err, Exception):
exec_stats.node_error_count += 1
if isinstance(err, InsufficientBalanceError):
low_balance_error = err
return callback
while not queue.empty():
@@ -724,6 +767,30 @@ class Executor:
f"Dispatching node execution {exec_data.node_exec_id} "
f"for node {exec_data.node_id}",
)
try:
exec_cost_counter = cls._charge_usage(
node_exec=exec_data,
execution_count=exec_cost_counter + 1,
execution_stats=exec_stats,
)
except InsufficientBalanceError as error:
exec_id = exec_data.node_exec_id
cls.db_client.upsert_execution_output(exec_id, "error", str(error))
exec_update = cls.db_client.update_execution_status(
exec_id, ExecutionStatus.FAILED
)
cls.db_client.send_execution_update(exec_update)
cls._handle_low_balance_notif(
graph_exec.user_id,
graph_exec.graph_id,
exec_stats,
error,
)
raise
running_executions[exec_data.node_id] = cls.executor.apply_async(
cls.on_node_execution,
(queue, exec_data),
@@ -747,32 +814,24 @@ class Executor:
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
if isinstance(low_balance_error, InsufficientBalanceError):
cls._handle_low_balance_notif(
graph_exec.user_id,
graph_exec.graph_id,
exec_stats,
low_balance_error,
)
raise low_balance_error
except Exception as e:
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
)
error = e
finally:
if error:
log_metadata.error(
f"Failed graph execution {graph_exec.graph_exec_id}: {error}"
)
execution_status = ExecutionStatus.FAILED
else:
execution_status = ExecutionStatus.COMPLETED
if not cancel.is_set():
finished = True
cancel.set()
cancel_thread.join()
clean_exec_files(graph_exec.graph_exec_id)
return (
exec_stats,
ExecutionStatus.FAILED if error else ExecutionStatus.COMPLETED,
error,
)
return exec_stats, execution_status, error
@classmethod
def _handle_agent_run_notif(

View File

@@ -0,0 +1,97 @@
from pydantic import BaseModel
from backend.data.block import Block, BlockInput
from backend.data.block_cost_config import BLOCK_COSTS
from backend.data.cost import BlockCostType
from backend.util.settings import Config
config = Config()
class UsageTransactionMetadata(BaseModel):
    """
    Metadata attached to a credit-usage transaction.

    Every field is optional and defaults to ``None``, so a charge only needs
    to set the fields that apply to it.
    """

    # Identifiers of the graph execution and the graph being charged for.
    graph_exec_id: str | None = None
    graph_id: str | None = None
    # Identifiers of the node and the node execution, when the charge is
    # tied to a specific node.
    node_id: str | None = None
    node_exec_id: str | None = None
    # Block ID and block name, when the charge is tied to a specific block.
    block_id: str | None = None
    block: str | None = None
    # Extra details about the charge.
    # NOTE(review): presumably the matched cost filter for block charges;
    # verify against callers.
    input: BlockInput | None = None
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
    """
    Calculate the cost of executing a graph based on the number of executions.

    One charge of ``config.execution_cost_per_threshold`` is applied for every
    full ``config.execution_cost_count_threshold`` executions; the leftover
    executions are returned so the caller can carry them over.

    Args:
        execution_count: Number of executions not yet charged for.

    Returns:
        Tuple of cost amount and remaining execution count. When the
        configured threshold is not positive, count-based charging is treated
        as disabled and ``(0, execution_count)`` is returned.
    """
    threshold = config.execution_cost_count_threshold
    if threshold <= 0:
        # A zero threshold would raise ZeroDivisionError, and a negative one
        # would yield negative charges/remainders; treat both as "disabled".
        return 0, execution_count

    charged_batches, remaining = divmod(execution_count, threshold)
    return charged_batches * config.execution_cost_per_threshold, remaining
def block_usage_cost(
    block: Block,
    input_data: BlockInput,
    data_size: float = 0,
    run_time: float = 0,
) -> tuple[int, BlockInput]:
    """
    Work out how much a single run of a block costs.

    The cost entries registered for the block's type are tried in order. The
    first entry whose filter matches the input and whose cost type is one of
    RUN, SECOND or BYTE decides the price.

    Args:
        block: Block object
        input_data: Input data for the block
        data_size: Multiplier applied to per-byte cost entries
        run_time: Multiplier applied to per-second cost entries

    Returns:
        Tuple of cost amount and the matched cost filter, or ``(0, {})`` when
        no cost entry applies.
    """
    candidates = BLOCK_COSTS.get(type(block))
    if not candidates:
        return 0, {}

    for candidate in candidates:
        if not _is_cost_filter_match(candidate.cost_filter, input_data):
            continue

        kind = candidate.cost_type
        if kind == BlockCostType.RUN:
            amount = candidate.cost_amount
        elif kind == BlockCostType.SECOND:
            amount = int(run_time * candidate.cost_amount)
        elif kind == BlockCostType.BYTE:
            amount = int(data_size * candidate.cost_amount)
        else:
            # Unrecognized cost type: keep looking at the remaining entries.
            continue
        return amount, candidate.cost_filter

    return 0, {}
def _is_cost_filter_match(cost_filter: BlockInput, input_data: BlockInput) -> bool:
    """
    Decide whether `input_data` satisfies `cost_filter`.

    Rules:
    - When both sides are dicts, every key of `cost_filter` must match the
      corresponding entry of `input_data` (recursively), i.e. the filter
      must be a subset of the input.
    - Otherwise the two values are compared for plain equality.
    - A falsy filter value matches a missing or falsy input value, so
      undefined, null and empty string are considered equal.
    """
    if not (isinstance(cost_filter, dict) and isinstance(input_data, dict)):
        return cost_filter == input_data

    for key, expected in cost_filter.items():
        actual = input_data.get(key)
        if not actual:
            # Input has nothing here: only acceptable if the filter also
            # expects "nothing".
            if expected:
                return False
        elif not _is_cost_filter_match(expected, actual):
            return False
    return True

View File

@@ -113,6 +113,14 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default="%Y-%W", # This will allow for weekly refunds per user.
description="Time key format for refund requests.",
)
# Size of one chargeable batch of executions: the execution-count charge is
# applied once per full multiple of this value. It is used as a divisor when
# the charge is computed, so it should be a positive integer.
execution_cost_count_threshold: int = Field(
    default=100,
    description="Number of executions after which the cost is calculated.",
)
# Amount charged for each full batch of `execution_cost_count_threshold`
# executions (i.e. per threshold reached, not per individual execution).
execution_cost_per_threshold: int = Field(
    default=1,
    description="Cost per execution in cents after each threshold.",
)
model_config = SettingsConfigDict(
env_file=".env",

View File

@@ -5,9 +5,11 @@ from prisma.enums import CreditTransactionType
from prisma.models import CreditTransaction
from backend.blocks.llm import AITextGeneratorBlock
from backend.data.block import get_block
from backend.data.credit import BetaUserCredit
from backend.data.execution import NodeExecutionEntry
from backend.data.user import DEFAULT_USER_ID
from backend.executor.utils import UsageTransactionMetadata, block_usage_cost
from backend.integrations.credentials_store import openai_credentials
from backend.util.test import SpinTestServer
@@ -27,13 +29,36 @@ async def top_up(amount: int):
)
async def spend_credits(entry: NodeExecutionEntry) -> int:
    """
    Test helper: charge the user for one node execution.

    The cost is computed with `block_usage_cost` from the entry's block and
    input data, and spent through `user_credit.spend_credits` only when it is
    positive.

    Args:
        entry: The node execution to charge for.

    Returns:
        The cost computed for the entry (0 when nothing was charged).

    Raises:
        RuntimeError: If the entry's block cannot be found.
    """
    block = get_block(entry.block_id)
    if not block:
        raise RuntimeError(f"Block {entry.block_id} not found")

    cost, matching_filter = block_usage_cost(block=block, input_data=entry.data)
    # Skip zero-cost entries so no empty usage transaction is recorded.
    if cost > 0:
        await user_credit.spend_credits(
            entry.user_id,
            cost,
            UsageTransactionMetadata(
                graph_exec_id=entry.graph_exec_id,
                graph_id=entry.graph_id,
                node_id=entry.node_id,
                node_exec_id=entry.node_exec_id,
                block_id=entry.block_id,
                # `block` carries the block's name; `block_id` already holds
                # the ID.
                block=block.name,
                input=matching_filter,
            ),
        )

    return cost
@pytest.mark.asyncio(scope="session")
async def test_block_credit_usage(server: SpinTestServer):
await disable_test_user_transactions()
await top_up(100)
current_credit = await user_credit.get_credits(DEFAULT_USER_ID)
spending_amount_1 = await user_credit.spend_credits(
spending_amount_1 = await spend_credits(
NodeExecutionEntry(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
@@ -50,12 +75,10 @@ async def test_block_credit_usage(server: SpinTestServer):
},
},
),
0.0,
0.0,
)
assert spending_amount_1 > 0
spending_amount_2 = await user_credit.spend_credits(
spending_amount_2 = await spend_credits(
NodeExecutionEntry(
user_id=DEFAULT_USER_ID,
graph_id="test_graph",
@@ -65,8 +88,6 @@ async def test_block_credit_usage(server: SpinTestServer):
block_id=AITextGeneratorBlock().id,
data={"model": "gpt-4-turbo", "api_key": "owned_api_key"},
),
0.0,
0.0,
)
assert spending_amount_2 == 0