mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(backend): Avoid executing any agent with zero balance (#9901)
### Changes 🏗️ * Avoid executing any agent with a zero balance. * Make node execution count global across agents for a single user. ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: <!-- Put your test plan here: --> - [x] Run agents by tweaking the `execution_cost_count_threshold` & `execution_cost_per_threshold` values.
This commit is contained in:
@@ -383,7 +383,7 @@ class UserCredit(UserCreditBase):
|
||||
notification_request: RefundRequestData,
|
||||
notification_type: NotificationType,
|
||||
):
|
||||
await self.notification_client().queue_notification(
|
||||
await self.notification_client().queue_notification_async(
|
||||
NotificationEventDTO(
|
||||
user_id=notification_request.user_id,
|
||||
type=notification_type,
|
||||
@@ -953,7 +953,7 @@ class BetaUserCredit(UserCredit):
|
||||
|
||||
class DisabledUserCredit(UserCreditBase):
|
||||
async def get_credits(self, *args, **kwargs) -> int:
|
||||
return 0
|
||||
return 100
|
||||
|
||||
async def get_transaction_history(self, *args, **kwargs) -> TransactionHistory:
|
||||
return TransactionHistory(transactions=[], next_transaction_time=None)
|
||||
|
||||
@@ -56,6 +56,10 @@ async def _spend_credits(
|
||||
return await _user_credit_model.spend_credits(user_id, cost, metadata)
|
||||
|
||||
|
||||
async def _get_credits(user_id: str) -> int:
|
||||
return await _user_credit_model.get_credits(user_id)
|
||||
|
||||
|
||||
class DatabaseManager(AppService):
|
||||
|
||||
def run_service(self) -> None:
|
||||
@@ -73,7 +77,11 @@ class DatabaseManager(AppService):
|
||||
return config.database_api_port
|
||||
|
||||
@staticmethod
|
||||
def _(f: Callable[P, R]) -> Callable[Concatenate[object, P], R]:
|
||||
def _(
|
||||
f: Callable[P, R], name: str | None = None
|
||||
) -> Callable[Concatenate[object, P], R]:
|
||||
if name is not None:
|
||||
f.__name__ = name
|
||||
return cast(Callable[Concatenate[object, P], R], expose(f))
|
||||
|
||||
# Executions
|
||||
@@ -97,7 +105,8 @@ class DatabaseManager(AppService):
|
||||
get_graph_metadata = _(get_graph_metadata)
|
||||
|
||||
# Credits
|
||||
spend_credits = _(_spend_credits)
|
||||
spend_credits = _(_spend_credits, name="spend_credits")
|
||||
get_credits = _(_get_credits, name="get_credits")
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = _(get_user_metadata)
|
||||
@@ -153,6 +162,7 @@ class DatabaseManagerClient(AppServiceClient):
|
||||
|
||||
# Credits
|
||||
spend_credits = _(d.spend_credits)
|
||||
get_credits = _(d.get_credits)
|
||||
|
||||
# User + User Metadata + User Integrations
|
||||
get_user_metadata = _(d.get_user_metadata)
|
||||
|
||||
@@ -598,11 +598,11 @@ class Executor:
|
||||
node_exec: NodeExecutionEntry,
|
||||
execution_count: int,
|
||||
execution_stats: GraphExecutionStats,
|
||||
) -> int:
|
||||
):
|
||||
block = get_block(node_exec.block_id)
|
||||
if not block:
|
||||
logger.error(f"Block {node_exec.block_id} not found.")
|
||||
return execution_count
|
||||
return
|
||||
|
||||
cost, matching_filter = block_usage_cost(block=block, input_data=node_exec.data)
|
||||
if cost > 0:
|
||||
@@ -622,7 +622,7 @@ class Executor:
|
||||
)
|
||||
execution_stats.cost += cost
|
||||
|
||||
cost, execution_count = execution_usage_cost(execution_count)
|
||||
cost, usage_count = execution_usage_cost(execution_count)
|
||||
if cost > 0:
|
||||
cls.db_client.spend_credits(
|
||||
user_id=node_exec.user_id,
|
||||
@@ -631,16 +631,14 @@ class Executor:
|
||||
graph_exec_id=node_exec.graph_exec_id,
|
||||
graph_id=node_exec.graph_id,
|
||||
input={
|
||||
"execution_count": execution_count,
|
||||
"execution_count": usage_count,
|
||||
"charge": "Execution Cost",
|
||||
},
|
||||
reason=f"Execution Cost for ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
|
||||
reason=f"Execution Cost for {usage_count} blocks of ex_id:{node_exec.graph_exec_id} g_id:{node_exec.graph_id}",
|
||||
),
|
||||
)
|
||||
execution_stats.cost += cost
|
||||
|
||||
return execution_count
|
||||
|
||||
@classmethod
|
||||
@time_measured
|
||||
def _on_graph_execution(
|
||||
@@ -677,11 +675,18 @@ class Executor:
|
||||
cancel_thread.start()
|
||||
|
||||
try:
|
||||
if cls.db_client.get_credits(graph_exec.user_id) <= 0:
|
||||
raise InsufficientBalanceError(
|
||||
user_id=graph_exec.user_id,
|
||||
message="You have no credits left to run an agent.",
|
||||
balance=0,
|
||||
amount=1,
|
||||
)
|
||||
|
||||
queue = ExecutionQueue[NodeExecutionEntry]()
|
||||
for node_exec in graph_exec.start_node_execs:
|
||||
queue.add(node_exec)
|
||||
|
||||
exec_cost_counter = 0
|
||||
running_executions: dict[str, AsyncResult] = {}
|
||||
|
||||
def make_exec_callback(exec_data: NodeExecutionEntry):
|
||||
@@ -737,9 +742,9 @@ class Executor:
|
||||
)
|
||||
|
||||
try:
|
||||
exec_cost_counter = cls._charge_usage(
|
||||
cls._charge_usage(
|
||||
node_exec=queued_node_exec,
|
||||
execution_count=exec_cost_counter + 1,
|
||||
execution_count=increment_execution_count(graph_exec.user_id),
|
||||
execution_stats=execution_stats,
|
||||
)
|
||||
except InsufficientBalanceError as error:
|
||||
@@ -1097,6 +1102,19 @@ def synchronized(key: str, timeout: int = 60):
|
||||
lock.release()
|
||||
|
||||
|
||||
def increment_execution_count(user_id: str) -> int:
|
||||
"""
|
||||
Increment the execution count for a given user,
|
||||
this will be used to charge the user for the execution cost.
|
||||
"""
|
||||
r = redis.get_redis()
|
||||
k = f"uec:{user_id}" # User Execution Count global key
|
||||
counter = cast(int, r.incr(k))
|
||||
if counter == 1:
|
||||
r.expire(k, settings.config.execution_counter_expiration_time)
|
||||
return counter
|
||||
|
||||
|
||||
def llprint(message: str):
|
||||
"""
|
||||
Low-level print/log helper function for use in signal handlers.
|
||||
|
||||
@@ -93,19 +93,21 @@ def get_db_client() -> "DatabaseManagerClient":
|
||||
|
||||
def execution_usage_cost(execution_count: int) -> tuple[int, int]:
|
||||
"""
|
||||
Calculate the cost of executing a graph based on the number of executions.
|
||||
Calculate the cost of executing a graph based on the current number of node executions.
|
||||
|
||||
Args:
|
||||
execution_count: Number of executions
|
||||
execution_count: Number of node executions
|
||||
|
||||
Returns:
|
||||
Tuple of cost amount and remaining execution count
|
||||
Tuple of cost amount and the number of execution count that is included in the cost.
|
||||
"""
|
||||
return (
|
||||
execution_count
|
||||
// config.execution_cost_count_threshold
|
||||
* config.execution_cost_per_threshold,
|
||||
execution_count % config.execution_cost_count_threshold,
|
||||
(
|
||||
config.execution_cost_per_threshold
|
||||
if execution_count % config.execution_cost_count_threshold == 0
|
||||
else 0
|
||||
),
|
||||
config.execution_cost_count_threshold,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -787,6 +787,7 @@ class NotificationManagerClient(AppServiceClient):
|
||||
def get_service_type(cls):
|
||||
return NotificationManager
|
||||
|
||||
queue_notification = endpoint_to_async(NotificationManager.queue_notification)
|
||||
queue_notification_async = endpoint_to_async(NotificationManager.queue_notification)
|
||||
queue_notification = NotificationManager.queue_notification
|
||||
process_existing_batches = NotificationManager.process_existing_batches
|
||||
queue_weekly_summary = NotificationManager.queue_weekly_summary
|
||||
|
||||
@@ -117,6 +117,10 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
default=1,
|
||||
description="Cost per execution in cents after each threshold.",
|
||||
)
|
||||
execution_counter_expiration_time: int = Field(
|
||||
default=60 * 60 * 24,
|
||||
description="Time in seconds after which the execution counter is reset.",
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
|
||||
Reference in New Issue
Block a user