fix(backend): Avoid executing any agent with zero balance (#9901)

### Changes 🏗️

* Avoid executing any agent with a zero balance.
* Make node execution count global across agents for a single user.

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Run agents by tweaking the `execution_cost_count_threshold` &
`execution_cost_per_threshold` values.
This commit is contained in:
Zamil Majdy
2025-05-01 22:11:38 +07:00
committed by GitHub
parent 86d5cfe60b
commit 475c5a5cc3
6 changed files with 57 additions and 22 deletions

View File

@@ -383,7 +383,7 @@ class UserCredit(UserCreditBase):
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await self.notification_client().queue_notification(
await self.notification_client().queue_notification_async(
NotificationEventDTO(
user_id=notification_request.user_id,
type=notification_type,
@@ -953,7 +953,7 @@ class BetaUserCredit(UserCredit):
class DisabledUserCredit(UserCreditBase):
async def get_credits(self, *args, **kwargs) -> int:
return 0
return 100
async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
return TransactionHistory(transactions=[], next_transaction_time=None)

View File

@@ -56,6 +56,10 @@ async def _spend_credits(
return await _user_credit_model.spend_credits(user_id, cost, metadata)
async def _get_credits(user_id: str) -> int:
return await _user_credit_model.get_credits(user_id)
class DatabaseManager(AppService):
def run_service(self) -> None:
@@ -73,7 +77,11 @@ class DatabaseManager(AppService):
return config.database_api_port
@staticmethod
def _(f: Callable[P, R]) -> Callable[Concatenate[object, P], R]:
def _(
f: Callable[P, R], name: str | None = None
) -> Callable[Concatenate[object, P], R]:
if name is not None:
f.__name__ = name
return cast(Callable[Concatenate[object, P], R], expose(f))
# Executions
@@ -97,7 +105,8 @@ class DatabaseManager(AppService):
get_graph_metadata = _(get_graph_metadata)
# Credits
spend_credits = _(_spend_credits)
spend_credits = _(_spend_credits, name="spend_credits")
get_credits = _(_get_credits, name="get_credits")
# User + User Metadata + User Integrations
get_user_metadata = _(get_user_metadata)
@@ -153,6 +162,7 @@ class DatabaseManagerClient(AppServiceClient):
# Credits
spend_credits = _(d.spend_credits)
get_credits = _(d.get_credits)
# User + User Metadata + User Integrations
get_user_metadata = _(d.get_user_metadata)

View File

@@ -598,11 +598,11 @@ class Executor:
node_exec: NodeExecutionEntry,
execution_count: int,
execution_stats: GraphExecutionStats,
) -> int:
):
block = get_block(node_exec.block_id)
if not block:
logger.error(f"Block {node_exec.block_id} not found.")
return execution_count
return
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
if cost > 0:
@@ -622,7 +622,7 @@ class Executor:
)
execution_stats.cost += cost
cost, execution_count = execution_usage_cost(execution_count)
cost, usage_count = execution_usage_cost(execution_count)
if cost > 0:
cls.db_client.spend_credits(
user_id=node_exec.user_id,
@@ -631,16 +631,14 @@ class Executor:
graph_exec_id=node_exec.graph_exec_id,
graph_id=node_exec.graph_id,
input={
"execution_count": execution_count,
"execution_count": usage_count,
"charge": "Execution Cost",
},
reason=f"Execution Cost for ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
),
)
execution_stats.cost += cost
return execution_count
@classmethod
@time_measured
def _on_graph_execution(
@@ -677,11 +675,18 @@ class Executor:
cancel_thread.start()
try:
if cls.db_client.get_credits(graph_exec.user_id) <= 0:
raise InsufficientBalanceError(
user_id=graph_exec.user_id,
message="You have no credits left to run an agent.",
balance=0,
amount=1,
)
queue = ExecutionQueue[NodeExecutionEntry]()
for node_exec in graph_exec.start_node_execs:
queue.add(node_exec)
exec_cost_counter = 0
running_executions: dict[str, AsyncResult] = {}
def make_exec_callback(exec_data: NodeExecutionEntry):
@@ -737,9 +742,9 @@ class Executor:
)
try:
exec_cost_counter = cls._charge_usage(
cls._charge_usage(
node_exec=queued_node_exec,
execution_count=exec_cost_counter + 1,
execution_count=increment_execution_count(graph_exec.user_id),
execution_stats=execution_stats,
)
except InsufficientBalanceError as error:
@@ -1097,6 +1102,19 @@ def synchronized(key: str, timeout: int = 60):
lock.release()
def increment_execution_count(user_id: str) -> int:
"""
Increment the execution count for a given user,
this will be used to charge the user for the execution cost.
"""
r = redis.get_redis()
k = f"uec:{user_id}" # User Execution Count global key
counter = cast(int, r.incr(k))
if counter == 1:
r.expire(k, settings.config.execution_counter_expiration_time)
return counter
def llprint(message: str):
"""
Low-level print/log helper function for use in signal handlers.

View File

@@ -93,19 +93,21 @@ def get_db_client() -> "DatabaseManagerClient":
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
"""
Calculate the cost of executing a graph based on the number of executions.
Calculate the cost of executing a graph based on the current number of node executions.
Args:
execution_count: Number of executions
execution_count: Number of node executions
Returns:
Tuple of cost amount and remaining execution count
Tuple of cost amount and the number of execution count that is included in the cost.
"""
return (
execution_count
// config.execution_cost_count_threshold
* config.execution_cost_per_threshold,
execution_count % config.execution_cost_count_threshold,
(
config.execution_cost_per_threshold
if execution_count % config.execution_cost_count_threshold == 0
else 0
),
config.execution_cost_count_threshold,
)

View File

@@ -787,6 +787,7 @@ class NotificationManagerClient(AppServiceClient):
def get_service_type(cls):
return NotificationManager
queue_notification = endpoint_to_async(NotificationManager.queue_notification)
queue_notification_async = endpoint_to_async(NotificationManager.queue_notification)
queue_notification = NotificationManager.queue_notification
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary

View File

@@ -117,6 +117,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
default=1,
description="Cost per execution in cents after each threshold.",
)
execution_counter_expiration_time: int = Field(
default=60 * 60 * 24,
description="Time in seconds after which the execution counter is reset.",
)
model_config = SettingsConfigDict(
env_file=".env",