Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
feat(backend): low balance notification (#9534)
We want to email the user when an agent stops because their balance is too low. This PR is the first step toward that.

### Changes 🏗️

- Raise `InsufficientBalanceError` from the credit system instead of a plain `ValueError` when the user runs out of money
- Handle the case where an agent output isn't hooked up correctly
- Align the contents of the low-balance email more closely with the PRD
- Expose the top-up intent from the DB manager
- Turn the execution stats into typed objects so they can be passed around more type-safely
- Extract the notification logic in the execution manager into its own functions

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Set balance to $0.01
  - [x] Run an agent that costs more than $0.01
  - [x] Check that you get an email
  - [x] Check that the top-up link works

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
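The pieces in the diff below fit together roughly like this: the credit system raises the typed error, the node executor records it on the node's execution stats, and the graph executor turns it into a `LOW_BALANCE` notification. A minimal sketch of that flow under the types introduced in this PR (the wiring is simplified and the user id and amounts are placeholders, not real values):

```python
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.util.exceptions import InsufficientBalanceError


def spend_credits_sketch(user_id: str, user_balance: int, amount: int) -> None:
    # Credit system: raise a typed error instead of a bare ValueError.
    if amount < 0 and user_balance + amount < 0:
        raise InsufficientBalanceError(
            message=f"Insufficient balance of ${user_balance/100}",
            user_id=user_id,
            balance=user_balance,
            amount=amount,
        )


def run_node_sketch(user_id: str) -> NodeExecutionStats:
    # Node executor: the error ends up on the node's execution stats.
    stats = NodeExecutionStats()
    try:
        spend_credits_sketch(user_id, user_balance=1, amount=-500)
    except InsufficientBalanceError as e:
        stats.error = e
    return stats


# Graph executor: aggregate node stats and remember the low-balance error,
# so a LOW_BALANCE notification can be queued once the run winds down.
graph_stats, low_balance_error = GraphExecutionStats(), None
node_stats = run_node_sketch("user-123")
graph_stats.node_count += 1
if isinstance(node_stats.error, InsufficientBalanceError):
    low_balance_error = node_stats.error
```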
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Opti
 from pydantic import BaseModel, SecretStr
 
+from backend.data.model import NodeExecutionStats
 from backend.integrations.providers import ProviderName
 
 if TYPE_CHECKING:
@@ -711,10 +712,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
             )
             response_text = llm_response.response
             self.merge_stats(
-                {
-                    "input_token_count": llm_response.prompt_tokens,
-                    "output_token_count": llm_response.completion_tokens,
-                }
+                NodeExecutionStats(
+                    input_token_count=llm_response.prompt_tokens,
+                    output_token_count=llm_response.completion_tokens,
+                )
             )
             logger.info(f"LLM attempt-{retry_count} response: {response_text}")
 
@@ -757,10 +758,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
                 retry_prompt = f"Error calling LLM: {e}"
             finally:
                 self.merge_stats(
-                    {
-                        "llm_call_count": retry_count + 1,
-                        "llm_retry_count": retry_count,
-                    }
+                    NodeExecutionStats(
+                        llm_call_count=retry_count + 1,
+                        llm_retry_count=retry_count,
+                    )
                 )
 
         raise RuntimeError(retry_prompt)
 
@@ -19,6 +19,7 @@ import jsonschema
 from prisma.models import AgentBlock
 from pydantic import BaseModel
 
+from backend.data.model import NodeExecutionStats
 from backend.util import json
 from backend.util.settings import Config
 
@@ -316,7 +317,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
         self.static_output = static_output
         self.block_type = block_type
         self.webhook_config = webhook_config
-        self.execution_stats = {}
+        self.execution_stats: NodeExecutionStats = NodeExecutionStats()
 
         if self.webhook_config:
             if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -394,18 +395,29 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
                 return data
         raise ValueError(f"{self.name} did not produce any output for {output}")
 
-    def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
-        for key, value in stats.items():
-            if isinstance(value, dict):
-                self.execution_stats.setdefault(key, {}).update(value)
-            elif isinstance(value, (int, float)):
-                self.execution_stats.setdefault(key, 0)
-                self.execution_stats[key] += value
-            elif isinstance(value, list):
-                self.execution_stats.setdefault(key, [])
-                self.execution_stats[key].extend(value)
+    def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
+        stats_dict = stats.model_dump()
+        current_stats = self.execution_stats.model_dump()
+
+        for key, value in stats_dict.items():
+            if key not in current_stats:
+                # Field doesn't exist yet, just set it, but this will probably
+                # not happen, just in case though so we throw for invalid when
+                # converting back in
+                current_stats[key] = value
+            elif isinstance(value, dict) and isinstance(current_stats[key], dict):
+                current_stats[key].update(value)
+            elif isinstance(value, (int, float)) and isinstance(
+                current_stats[key], (int, float)
+            ):
+                current_stats[key] += value
+            elif isinstance(value, list) and isinstance(current_stats[key], list):
+                current_stats[key].extend(value)
             else:
-                self.execution_stats[key] = value
+                current_stats[key] = value
+
+        self.execution_stats = NodeExecutionStats(**current_stats)
 
         return self.execution_stats
 
     @property
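The reworked `merge_stats` above accumulates numeric fields, merges dicts, extends lists, and round-trips everything through the pydantic model so type validation happens on the way back in. A small self-contained sketch of that accumulation behaviour (a simplified re-implementation, not the exact method from the diff; the token counts are invented):

```python
from backend.data.model import NodeExecutionStats


class StatsHolder:
    """Stand-in for Block, keeping only the stats-merging behaviour."""

    def __init__(self) -> None:
        self.execution_stats = NodeExecutionStats()

    def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
        current = self.execution_stats.model_dump()
        for key, value in stats.model_dump().items():
            if isinstance(value, (int, float)) and isinstance(current.get(key), (int, float)):
                current[key] += value  # numeric fields accumulate across calls
            else:
                current[key] = value   # everything else is overwritten
        self.execution_stats = NodeExecutionStats(**current)
        return self.execution_stats


holder = StatsHolder()
holder.merge_stats(NodeExecutionStats(input_token_count=120, output_token_count=45))
holder.merge_stats(NodeExecutionStats(input_token_count=80, llm_call_count=1))
assert holder.execution_stats.input_token_count == 200  # 120 + 80
```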
@@ -32,6 +32,7 @@ from backend.data.model import (
 from backend.data.notifications import NotificationEventDTO, RefundRequestData
 from backend.data.user import get_user_by_id
 from backend.notifications import NotificationManager
+from backend.util.exceptions import InsufficientBalanceError
 from backend.util.service import get_service_client
 from backend.util.settings import Settings
 
@@ -313,9 +314,13 @@ class UserCreditBase(ABC):
 
         if amount < 0 and user_balance + amount < 0:
             if fail_insufficient_credits:
-                raise ValueError(
-                    f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}"
+                raise InsufficientBalanceError(
+                    message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
+                    user_id=user_id,
+                    balance=user_balance,
+                    amount=amount,
                 )
 
             amount = min(-user_balance, 0)
 
         # Create the transaction
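Because `InsufficientBalanceError` subclasses `ValueError`, existing `except ValueError` handlers keep working, while new callers can branch on the richer type. A hedged sketch of that behaviour (the `spend_credits` function here is a local stub, not the real credit-system API):

```python
from backend.util.exceptions import InsufficientBalanceError


def spend_credits(user_id: str, balance: int, amount: int) -> None:
    """Stub standing in for the credit system's spend path (illustrative only)."""
    if amount < 0 and balance + amount < 0:
        raise InsufficientBalanceError(
            message="Insufficient balance",
            user_id=user_id,
            balance=balance,
            amount=amount,
        )


try:
    spend_credits("user-123", balance=1, amount=-500)
except InsufficientBalanceError as e:
    # The typed error carries everything the low-balance notification needs.
    print(f"user={e.user_id} balance={e.balance} attempted={e.amount}")
except ValueError:
    # Older call sites that only know about ValueError still catch it,
    # since InsufficientBalanceError subclasses ValueError.
    pass
```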
@@ -15,6 +15,7 @@ from pydantic import BaseModel
 
 from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
 from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
+from backend.data.model import GraphExecutionStats, NodeExecutionStats
 from backend.data.queue import AsyncRedisEventBus, RedisEventBus
 from backend.server.v2.store.exceptions import DatabaseError
 from backend.util import mock, type
@@ -282,13 +283,16 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
 async def update_graph_execution_stats(
     graph_exec_id: str,
     status: ExecutionStatus,
-    stats: dict[str, Any],
+    stats: GraphExecutionStats,
 ) -> ExecutionResult:
+    data = stats.model_dump()
+    if isinstance(data["error"], Exception):
+        data["error"] = str(data["error"])
     res = await AgentGraphExecution.prisma().update(
         where={"id": graph_exec_id},
         data={
             "executionStatus": status,
-            "stats": Json(stats),
+            "stats": Json(data),
         },
     )
     if not res:
@@ -297,10 +301,13 @@ async def update_graph_execution_stats(
     return ExecutionResult.from_graph(res)
 
 
-async def update_node_execution_stats(node_exec_id: str, stats: dict[str, Any]):
+async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
+    data = stats.model_dump()
+    if isinstance(data["error"], Exception):
+        data["error"] = str(data["error"])
     await AgentNodeExecution.prisma().update(
         where={"id": node_exec_id},
-        data={"stats": Json(stats)},
+        data={"stats": Json(data)},
     )
 
 
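Both update functions follow the same pattern: dump the pydantic stats model to a plain dict and stringify any live exception before handing the payload to the database, since an `Exception` instance is not JSON-serializable. A minimal standalone sketch of that conversion, using plain `json.dumps` instead of the Prisma `Json` wrapper:

```python
import json

from backend.data.model import NodeExecutionStats


def stats_to_json(stats: NodeExecutionStats) -> str:
    data = stats.model_dump()
    # The error field may hold a live Exception (arbitrary_types_allowed);
    # convert it to a string so the payload becomes JSON-serializable.
    if isinstance(data["error"], Exception):
        data["error"] = str(data["error"])
    return json.dumps(data)


print(stats_to_json(NodeExecutionStats(error=RuntimeError("boom"), walltime=1.25)))
```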
@@ -186,7 +186,9 @@ class GraphExecution(GraphExecutionMeta):
         outputs: dict[str, list] = defaultdict(list)
         for exec in node_executions:
             if exec.block_id == _OUTPUT_BLOCK_ID:
-                outputs[exec.input_data["name"]].append(exec.input_data["value"])
+                outputs[exec.input_data["name"]].append(
+                    exec.input_data.get("value", None)
+                )
 
         return GraphExecution(
             **{
@@ -402,3 +402,37 @@ class RefundRequest(BaseModel):
     status: str
     created_at: datetime
     updated_at: datetime
+
+
+class NodeExecutionStats(BaseModel):
+    """Execution statistics for a node execution."""
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    error: Optional[Exception | str] = None
+    walltime: float = 0
+    cputime: float = 0
+    cost: float = 0
+    input_size: int = 0
+    output_size: int = 0
+    llm_call_count: int = 0
+    llm_retry_count: int = 0
+    input_token_count: int = 0
+    output_token_count: int = 0
+
+
+class GraphExecutionStats(BaseModel):
+    """Execution statistics for a graph execution."""
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    error: Optional[Exception | str] = None
+    walltime: float = 0
+    cputime: float = 0
+    nodes_walltime: float = 0
+    nodes_cputime: float = 0
+    node_count: int = 0
+    node_error_count: int = 0
+    cost: float = 0
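`arbitrary_types_allowed` is what lets the `error` field temporarily hold a live `Exception` before the executor stringifies it; everything else is plain numbers with zero defaults, so a freshly constructed object is a valid "empty" stats record. A quick sketch of how the graph-level object aggregates node-level ones (field names from the models above, numbers invented):

```python
from backend.data.model import GraphExecutionStats, NodeExecutionStats

graph_stats = GraphExecutionStats()  # all counters start at 0
node_stats = NodeExecutionStats(
    walltime=2.5,
    cputime=0.4,
    cost=12,
    error=TimeoutError("upstream API timed out"),  # allowed via arbitrary_types_allowed
)

# The graph executor folds each node's stats into the run-level totals.
graph_stats.node_count += 1
graph_stats.nodes_walltime += node_stats.walltime
graph_stats.nodes_cputime += node_stats.cputime
graph_stats.cost += node_stats.cost
if isinstance(node_stats.error, Exception):
    graph_stats.node_error_count += 1

assert graph_stats.node_error_count == 1
```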
@@ -49,10 +49,12 @@ class ZeroBalanceData(BaseNotificationData):
 
 
 class LowBalanceData(BaseNotificationData):
-    current_balance: float
-    threshold_amount: float
-    top_up_link: str
-    recent_usage: float = Field(..., description="Usage in the last 24 hours")
+    agent_name: str = Field(..., description="Name of the agent")
+    current_balance: float = Field(
+        ..., description="Current balance in credits (100 = $1)"
+    )
+    billing_page_link: str = Field(..., description="Link to billing page")
+    shortfall: float = Field(..., description="Amount of credits needed to continue")
 
 
 class BlockExecutionFailedData(BaseNotificationData):
@@ -197,7 +199,7 @@ class NotificationTypeOverride:
             NotificationType.AGENT_RUN: QueueType.IMMEDIATE,
             # These are batched by the notification service, but with a backoff strategy
             NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
-            NotificationType.LOW_BALANCE: QueueType.BACKOFF,
+            NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
             NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
             NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
             NotificationType.DAILY_SUMMARY: QueueType.DAILY,
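With the reshaped `LowBalanceData`, queuing the low-balance email comes down to filling the four fields and wrapping them in a `NotificationEventDTO`, which now goes out immediately rather than via the backoff queue. A hedged sketch (user id, links, and amounts are placeholders; the DTO field names follow the executor code later in this diff):

```python
from backend.data.notifications import (
    LowBalanceData,
    NotificationEventDTO,
    NotificationType,
)

event = NotificationEventDTO(
    user_id="user-123",                  # placeholder
    type=NotificationType.LOW_BALANCE,
    data=LowBalanceData(
        agent_name="Newsletter Agent",   # placeholder
        current_balance=42,              # credits, i.e. $0.42
        billing_page_link="https://platform.example.com/profile/credits",
        shortfall=58,                    # credits still needed to continue
    ).model_dump(),
)

# The executor hands this DTO to the notification service, which renders the
# low-balance email template and sends it right away (QueueType.IMMEDIATE).
```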
@@ -13,11 +13,14 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
 from redis.lock import Lock as RedisLock
 
 from backend.blocks.basic import AgentOutputBlock
+from backend.data.model import GraphExecutionStats, NodeExecutionStats
 from backend.data.notifications import (
     AgentRunData,
+    LowBalanceData,
     NotificationEventDTO,
     NotificationType,
 )
+from backend.util.exceptions import InsufficientBalanceError
 
 if TYPE_CHECKING:
     from backend.executor import DatabaseManager
@@ -117,7 +120,7 @@ def execute_node(
     db_client: "DatabaseManager",
     creds_manager: IntegrationCredentialsManager,
     data: NodeExecutionEntry,
-    execution_stats: dict[str, Any] | None = None,
+    execution_stats: NodeExecutionStats | None = None,
 ) -> ExecutionStream:
     """
     Execute a node in the graph. This will trigger a block execution on a node,
@@ -126,7 +129,6 @@ def execute_node(
     Args:
         db_client: The client to send execution updates to the server.
         creds_manager: The manager to acquire and release credentials.
-        notification_service: The service to send notifications.
         data: The execution data for executing the current node.
         execution_stats: The execution statistics to be updated.
 
@@ -256,10 +258,12 @@ def execute_node(
 
         # Update execution stats
         if execution_stats is not None:
-            execution_stats.update(node_block.execution_stats)
-            execution_stats["input_size"] = input_size
-            execution_stats["output_size"] = output_size
-            execution_stats["cost"] = cost
+            execution_stats = execution_stats.model_copy(
+                update=node_block.execution_stats.model_dump()
+            )
+            execution_stats.input_size = input_size
+            execution_stats.output_size = output_size
+            execution_stats.cost = cost
 
 
 def _enqueue_next_nodes(
@@ -476,7 +480,6 @@ class Executor:
         cls.pid = os.getpid()
         cls.db_client = get_db_client()
         cls.creds_manager = IntegrationCredentialsManager()
-        cls.notification_service = get_notification_service()
 
         # Set up shutdown handlers
         cls.shutdown_lock = threading.Lock()
@@ -517,7 +520,7 @@ class Executor:
         cls,
         q: ExecutionQueue[NodeExecutionEntry],
         node_exec: NodeExecutionEntry,
-    ) -> dict[str, Any]:
+    ) -> NodeExecutionStats:
         log_metadata = LogMetadata(
             user_id=node_exec.user_id,
             graph_eid=node_exec.graph_exec_id,
@@ -527,13 +530,15 @@ class Executor:
             block_name="-",
         )
 
-        execution_stats = {}
+        execution_stats = NodeExecutionStats()
        timing_info, _ = cls._on_node_execution(
            q, node_exec, log_metadata, execution_stats
        )
-        execution_stats["walltime"] = timing_info.wall_time
-        execution_stats["cputime"] = timing_info.cpu_time
+        execution_stats.walltime = timing_info.wall_time
+        execution_stats.cputime = timing_info.cpu_time
+
+        if isinstance(execution_stats.error, Exception):
+            execution_stats.error = str(execution_stats.error)
         cls.db_client.update_node_execution_stats(
             node_exec.node_exec_id, execution_stats
         )
@@ -546,7 +551,7 @@ class Executor:
         q: ExecutionQueue[NodeExecutionEntry],
         node_exec: NodeExecutionEntry,
         log_metadata: LogMetadata,
-        stats: dict[str, Any] | None = None,
+        stats: NodeExecutionStats | None = None,
     ):
         try:
             log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
@@ -569,6 +574,9 @@ class Executor:
                 f"Failed node execution {node_exec.node_exec_id}: {e}"
             )
 
+            if stats is not None:
+                stats.error = e
+
     @classmethod
     def on_graph_executor_start(cls):
         configure_logging()
@@ -577,6 +585,7 @@ class Executor:
         cls.db_client = get_db_client()
         cls.pool_size = settings.config.num_node_workers
         cls.pid = os.getpid()
+        cls.notification_service = get_notification_service()
         cls._init_node_executor_pool()
         logger.info(
             f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
@@ -618,9 +627,12 @@ class Executor:
         timing_info, (exec_stats, status, error) = cls._on_graph_execution(
             graph_exec, cancel, log_metadata
         )
-        exec_stats["walltime"] = timing_info.wall_time
-        exec_stats["cputime"] = timing_info.cpu_time
-        exec_stats["error"] = str(error) if error else None
+        exec_stats.walltime = timing_info.wall_time
+        exec_stats.cputime = timing_info.cpu_time
+        exec_stats.error = error
+
+        if isinstance(exec_stats.error, Exception):
+            exec_stats.error = str(exec_stats.error)
         result = cls.db_client.update_graph_execution_stats(
             graph_exec_id=graph_exec.graph_exec_id,
             status=status,
@@ -628,48 +640,7 @@ class Executor:
         )
         cls.db_client.send_execution_update(result)
 
-        metadata = cls.db_client.get_graph_metadata(
-            graph_exec.graph_id, graph_exec.graph_version
-        )
-        assert metadata is not None
-        outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
-
-        # Collect named outputs as a list of dictionaries
-        named_outputs = []
-        for output in outputs:
-            if output.block_id == AgentOutputBlock().id:
-                # Create a dictionary for this named output
-                named_output = {
-                    # Include the name as a field in each output
-                    "name": (
-                        output.output_data["name"][0]
-                        if isinstance(output.output_data["name"], list)
-                        else output.output_data["name"]
-                    )
-                }
-
-                # Add all other fields
-                for key, value in output.output_data.items():
-                    if key != "name":
-                        named_output[key] = value
-
-                named_outputs.append(named_output)
-
-        event = NotificationEventDTO(
-            user_id=graph_exec.user_id,
-            type=NotificationType.AGENT_RUN,
-            data=AgentRunData(
-                outputs=named_outputs,
-                agent_name=metadata.name,
-                credits_used=exec_stats["cost"],
-                execution_time=timing_info.wall_time,
-                graph_id=graph_exec.graph_id,
-                node_count=exec_stats["node_count"],
-            ).model_dump(),
-        )
-
-        logger.info(f"Sending notification for {event}")
-        get_notification_service().queue_notification(event)
+        cls._handle_agent_run_notif(graph_exec, exec_stats)
 
     @classmethod
     @time_measured
@@ -678,7 +649,7 @@ class Executor:
         graph_exec: GraphExecutionEntry,
         cancel: threading.Event,
         log_metadata: LogMetadata,
-    ) -> tuple[dict[str, Any], ExecutionStatus, Exception | None]:
+    ) -> tuple[GraphExecutionStats, ExecutionStatus, Exception | None]:
         """
         Returns:
             dict: The execution statistics of the graph execution.
@@ -686,12 +657,7 @@ class Executor:
             Exception | None: The error that occurred during the execution, if any.
         """
         log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
-        exec_stats = {
-            "nodes_walltime": 0,
-            "nodes_cputime": 0,
-            "node_count": 0,
-            "cost": 0,
-        }
+        exec_stats = GraphExecutionStats()
         error = None
         finished = False
 
@@ -717,18 +683,26 @@ class Executor:
                 queue.add(node_exec)
 
             running_executions: dict[str, AsyncResult] = {}
+            low_balance_error: Optional[InsufficientBalanceError] = None
 
             def make_exec_callback(exec_data: NodeExecutionEntry):
-                node_id = exec_data.node_id
 
                 def callback(result: object):
-                    running_executions.pop(node_id)
-                    nonlocal exec_stats
-                    if isinstance(result, dict):
-                        exec_stats["node_count"] += 1
-                        exec_stats["nodes_cputime"] += result.get("cputime", 0)
-                        exec_stats["nodes_walltime"] += result.get("walltime", 0)
-                        exec_stats["cost"] += result.get("cost", 0)
+                    running_executions.pop(exec_data.node_id)
+
+                    if not isinstance(result, NodeExecutionStats):
+                        return
+
+                    nonlocal exec_stats, low_balance_error
+                    exec_stats.node_count += 1
+                    exec_stats.nodes_cputime += result.cputime
+                    exec_stats.nodes_walltime += result.walltime
+                    exec_stats.cost += result.cost
+                    if (err := result.error) and isinstance(err, Exception):
+                        exec_stats.node_error_count += 1
+
+                        if isinstance(err, InsufficientBalanceError):
+                            low_balance_error = err
 
                 return callback
 
@@ -773,6 +747,16 @@ class Executor:
                     execution.wait(3)
 
             log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
+
+            if isinstance(low_balance_error, InsufficientBalanceError):
+                cls._handle_low_balance_notif(
+                    graph_exec.user_id,
+                    graph_exec.graph_id,
+                    exec_stats,
+                    low_balance_error,
+                )
+                raise low_balance_error
+
         except Exception as e:
             log_metadata.exception(
                 f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
@@ -791,6 +775,67 @@ class Executor:
             error,
         )
 
+    @classmethod
+    def _handle_agent_run_notif(
+        cls,
+        graph_exec: GraphExecutionEntry,
+        exec_stats: GraphExecutionStats,
+    ):
+        metadata = cls.db_client.get_graph_metadata(
+            graph_exec.graph_id, graph_exec.graph_version
+        )
+        outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
+
+        named_outputs = [
+            {
+                key: value[0] if key == "name" else value
+                for key, value in output.output_data.items()
+            }
+            for output in outputs
+            if output.block_id == AgentOutputBlock().id
+        ]
+
+        event = NotificationEventDTO(
+            user_id=graph_exec.user_id,
+            type=NotificationType.AGENT_RUN,
+            data=AgentRunData(
+                outputs=named_outputs,
+                agent_name=metadata.name if metadata else "Unknown Agent",
+                credits_used=exec_stats.cost,
+                execution_time=exec_stats.walltime,
+                graph_id=graph_exec.graph_id,
+                node_count=exec_stats.node_count,
+            ).model_dump(),
+        )
+
+        cls.notification_service.queue_notification(event)
+
+    @classmethod
+    def _handle_low_balance_notif(
+        cls,
+        user_id: str,
+        graph_id: str,
+        exec_stats: GraphExecutionStats,
+        e: InsufficientBalanceError,
+    ):
+        shortfall = e.balance - e.amount
+        metadata = cls.db_client.get_graph_metadata(graph_id)
+        base_url = (
+            settings.config.frontend_base_url or settings.config.platform_base_url
+        )
+        cls.notification_service.queue_notification(
+            NotificationEventDTO(
+                user_id=user_id,
+                type=NotificationType.LOW_BALANCE,
+                data=LowBalanceData(
+                    current_balance=exec_stats.cost,
+                    billing_page_link=f"{base_url}/profile/credits",
+                    shortfall=shortfall,
+                    agent_name=metadata.name if metadata else "Unknown Agent",
+                ).model_dump(),
+            )
+        )
+
 
 class ExecutionManager(AppService):
     def __init__(self):
@@ -0,0 +1,114 @@
+{# Low Balance Notification Email Template #}
+{# Template variables:
+    data.agent_name: the name of the agent
+    data.current_balance: the current balance of the user
+    data.billing_page_link: the link to the billing page
+    data.shortfall: the shortfall amount
+#}
+
+<p style="
+    font-family: 'Poppins', sans-serif;
+    color: #070629;
+    font-size: 16px;
+    line-height: 165%;
+    margin-top: 0;
+    margin-bottom: 10px;
+">
+    <strong>Low Balance Warning</strong>
+</p>
+
+<p style="
+    font-family: 'Poppins', sans-serif;
+    color: #070629;
+    font-size: 16px;
+    line-height: 165%;
+    margin-top: 0;
+    margin-bottom: 20px;
+">
+    Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
+</p>
+
+<div style="
+    margin-left: 15px;
+    margin-bottom: 20px;
+    padding: 15px;
+    border-left: 4px solid #5D23BB;
+    background-color: #f8f8ff;
+">
+    <p style="
+        font-family: 'Poppins', sans-serif;
+        color: #070629;
+        font-size: 16px;
+        margin-top: 0;
+        margin-bottom: 10px;
+    ">
+        <strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
+    </p>
+    <p style="
+        font-family: 'Poppins', sans-serif;
+        color: #070629;
+        font-size: 16px;
+        margin-top: 0;
+        margin-bottom: 10px;
+    ">
+        <strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
+    </p>
+</div>
+
+
+<div style="
+    margin-left: 15px;
+    margin-bottom: 20px;
+    padding: 15px;
+    border-left: 4px solid #FF6B6B;
+    background-color: #FFF0F0;
+">
+    <p style="
+        font-family: 'Poppins', sans-serif;
+        color: #070629;
+        font-size: 16px;
+        margin-top: 0;
+        margin-bottom: 10px;
+    ">
+        <strong>Low Balance:</strong>
+    </p>
+    <p style="
+        font-family: 'Poppins', sans-serif;
+        color: #070629;
+        font-size: 16px;
+        margin-top: 0;
+        margin-bottom: 5px;
+    ">
+        Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
+    </p>
+</div>
+
+<div style="
+    text-align: center;
+    margin: 30px 0;
+">
+    <a href="{{ data.billing_page_link }}" style="
+        font-family: 'Poppins', sans-serif;
+        background-color: #5D23BB;
+        color: white;
+        padding: 12px 24px;
+        text-decoration: none;
+        border-radius: 4px;
+        font-weight: 500;
+        display: inline-block;
+    ">
+        Manage Billing
+    </a>
+</div>
+
+<p style="
+    font-family: 'Poppins', sans-serif;
+    color: #070629;
+    font-size: 16px;
+    line-height: 150%;
+    margin-top: 30px;
+    margin-bottom: 10px;
+    font-style: italic;
+">
+    This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
+</p>
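The template only reads the four `data.*` fields, so it can be rendered in isolation for a quick preview. A hedged sketch of how the notification service might fill it in (the template path and file name here are assumptions, not the project's actual loader configuration):

```python
from jinja2 import Environment, FileSystemLoader, select_autoescape

# Assumed location/name for the template shown above.
env = Environment(
    loader=FileSystemLoader("templates"),
    autoescape=select_autoescape(["html", "j2"]),
)
template = env.get_template("low_balance.html.j2")

html = template.render(
    data={
        "agent_name": "Newsletter Agent",  # placeholder values
        "current_balance": 42,             # credits (100 = $1)
        "shortfall": 58,
        "billing_page_link": "https://platform.example.com/profile/credits",
    }
)
print(html[:200])
```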
@@ -4,3 +4,22 @@ class MissingConfigError(Exception):
 
 class NeedConfirmation(Exception):
     """The user must explicitly confirm that they want to proceed"""
+
+
+class InsufficientBalanceError(ValueError):
+    user_id: str
+    message: str
+    balance: float
+    amount: float
+
+    def __init__(self, message: str, user_id: str, balance: float, amount: float):
+        super().__init__(message)
+        self.args = (message, user_id, balance, amount)
+        self.message = message
+        self.user_id = user_id
+        self.balance = balance
+        self.amount = amount
+
+    def __str__(self):
+        """Used to display the error message in the frontend, because we str() the error when sending the execution update"""
+        return self.message
 
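Overriding `self.args` is what lets the RPC layer rebuild the exception on the client side: the service ships the exception type name and its `.args`, and the client does `ExceptionClass(*args)`, so the tuple must match the constructor signature (this is what the `DEBUG HELP` comment in the next hunk refers to). A minimal sketch of that round trip, independent of the actual Pyro/FastAPI plumbing:

```python
from backend.util.exceptions import InsufficientBalanceError

# Server side: the error is raised somewhere in the credit system.
original = InsufficientBalanceError(
    message="Insufficient balance of $0.01, where this will cost $0.50",
    user_id="user-123",
    balance=1,
    amount=-50,
)

# The service layer ships only the exception type name and its .args tuple.
wire_payload = {"type": type(original).__name__, "args": list(original.args)}

# Client side: look the type up and re-create it from the args, mirroring
# `EXCEPTION_MAPPING.get(error.type, Exception)(*error.args)`.
exception_types = {"InsufficientBalanceError": InsufficientBalanceError}
rebuilt = exception_types[wire_payload["type"]](*wire_payload["args"])

assert isinstance(rebuilt, ValueError)
assert rebuilt.balance == original.balance and str(rebuilt) == str(original)
```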
@@ -42,6 +42,7 @@ from Pyro5 import api as pyro
 from Pyro5 import config as pyro_config
 
 from backend.data import db, rabbitmq, redis
+from backend.util.exceptions import InsufficientBalanceError
 from backend.util.json import to_dict
 from backend.util.process import AppProcess
 from backend.util.retry import conn_retry
@@ -251,6 +252,7 @@ EXCEPTION_MAPPING = {
         ValueError,
         TimeoutError,
         ConnectionError,
+        InsufficientBalanceError,
     ]
 }
 
@@ -441,6 +443,7 @@ def fastapi_get_service_client(service_type: Type[AS]) -> AS:
         except httpx.HTTPStatusError as e:
             logger.error(f"HTTP error in {method_name}: {e.response.text}")
             error = RemoteCallError.model_validate(e.response.json(), strict=False)
+            # DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
             raise EXCEPTION_MAPPING.get(error.type, Exception)(
                 *(error.args or [str(e)])
             )
 
@@ -21,46 +21,83 @@ class TextFormatter:
         self.env.globals.clear()
 
         # Instead of clearing all filters, just remove potentially unsafe ones
-        unsafe_filters = ["pprint", "urlize", "xmlattr", "tojson"]
+        unsafe_filters = ["pprint", "tojson", "urlize", "xmlattr"]
         for f in unsafe_filters:
             if f in self.env.filters:
                 del self.env.filters[f]
 
         self.env.filters["format"] = format_filter_for_jinja2
 
-        # Define allowed CSS properties
+        # Define allowed CSS properties (sorted alphabetically, if you add more)
         allowed_css_properties = [
-            "font-family",
+            "background-color",
+            "border",
+            "border-bottom",
+            "border-color",
+            "border-left",
+            "border-radius",
+            "border-right",
+            "border-style",
+            "border-top",
+            "border-width",
+            "bottom",
+            "box-shadow",
+            "clear",
+            "color",
+            "display",
+            "float",
+            "font-family",
             "font-size",
+            "font-weight",
+            "height",
+            "left",
+            "letter-spacing",
             "line-height",
-            "margin-top",
             "margin-bottom",
            "margin-left",
            "margin-right",
-            "background-color",
+            "margin-top",
             "padding",
-            "border-radius",
-            "font-weight",
+            "position",
+            "right",
             "text-align",
+            "text-decoration",
+            "text-shadow",
+            "text-transform",
+            "top",
+            "width",
         ]
 
         self.css_sanitizer = CSSSanitizer(allowed_css_properties=allowed_css_properties)
 
+        # Define allowed tags (sorted alphabetically, if you add more)
         self.allowed_tags = [
-            "p",
+            "a",
             "b",
+            "br",
+            "div",
+            "em",
             "h1",
             "h2",
             "h3",
             "h4",
             "h5",
             "i",
+            "img",
+            "li",
+            "p",
+            "span",
+            "strong",
             "u",
             "ul",
-            "li",
-            "br",
-            "strong",
-            "em",
-            "div",
-            "span",
         ]
-        self.allowed_attributes = {"*": ["style", "class"]}
+
+        # Define allowed attributes to be used on specific tags
+        self.allowed_attributes = {
+            "*": ["class", "style"],
+            "a": ["href"],
+            "img": ["src"],
+        }
 
     def format_string(self, template_str: str, values=None, **kwargs) -> str:
         """Regular template rendering with escaping"""
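The sanitizer settings above feed into bleach: the tag and attribute allow-lists plus the `CSSSanitizer` restrict what rendered notification HTML may contain. A small hedged sketch of how such a configuration is typically applied with `bleach.clean` (a trimmed allow-list here, not the full one from the diff):

```python
import bleach
from bleach.css_sanitizer import CSSSanitizer

css_sanitizer = CSSSanitizer(allowed_css_properties=["color", "font-size"])

dirty = (
    '<a href="https://example.com" onclick="steal()" '
    'style="color: red; position: fixed">hi</a>'
    "<script>alert(1)</script>"
)

clean = bleach.clean(
    dirty,
    tags=["a", "p", "strong"],
    attributes={"*": ["class", "style"], "a": ["href"]},
    css_sanitizer=css_sanitizer,
)

# The onclick handler is dropped, the disallowed CSS property is stripped,
# and the <script> element is escaped rather than executed.
print(clean)
```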