feat(backend): low balance notiifcation (#9534)

<!-- Clearly explain the need for these changes: -->
For emailing, we want the user to know when an agent stopped because
their balance was too low. This is the first step of that.

### Changes 🏗️
- Raise InsufficientBalanceError from credit system rather than value
error when user runs out of money
- Handle when an agent output isn't hooked up well
- Fix the contents of the email for low balance to be a bit more aligned
with the PRD
- expose the topup intent from the db manager
- objectify the execution stats so we can pass it around a bit more type
safe
- extract the notification stuff in manager into a function
<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
  - [x] Set balance to $0.01
  - [x] Run an agent that costs something more than $0.01
  - [x] Check you get an email
  - [x] Check your top up link works

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
This commit is contained in:
Nicholas Tindle
2025-03-04 23:16:57 -06:00
committed by GitHub
parent 27a5635607
commit 265a9265f7
12 changed files with 400 additions and 119 deletions

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Opti
from pydantic import BaseModel, SecretStr
from backend.data.model import NodeExecutionStats
from backend.integrations.providers import ProviderName
if TYPE_CHECKING:
@@ -711,10 +712,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
)
response_text = llm_response.response
self.merge_stats(
{
"input_token_count": llm_response.prompt_tokens,
"output_token_count": llm_response.completion_tokens,
}
NodeExecutionStats(
input_token_count=llm_response.prompt_tokens,
output_token_count=llm_response.completion_tokens,
)
)
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
@@ -757,10 +758,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
retry_prompt = f"Error calling LLM: {e}"
finally:
self.merge_stats(
{
"llm_call_count": retry_count + 1,
"llm_retry_count": retry_count,
}
NodeExecutionStats(
llm_call_count=retry_count + 1,
llm_retry_count=retry_count,
)
)
raise RuntimeError(retry_prompt)

View File

@@ -19,6 +19,7 @@ import jsonschema
from prisma.models import AgentBlock
from pydantic import BaseModel
from backend.data.model import NodeExecutionStats
from backend.util import json
from backend.util.settings import Config
@@ -316,7 +317,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
self.static_output = static_output
self.block_type = block_type
self.webhook_config = webhook_config
self.execution_stats = {}
self.execution_stats: NodeExecutionStats = NodeExecutionStats()
if self.webhook_config:
if isinstance(self.webhook_config, BlockWebhookConfig):
@@ -394,18 +395,29 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]):
return data
raise ValueError(f"{self.name} did not produce any output for {output}")
def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]:
for key, value in stats.items():
if isinstance(value, dict):
self.execution_stats.setdefault(key, {}).update(value)
elif isinstance(value, (int, float)):
self.execution_stats.setdefault(key, 0)
self.execution_stats[key] += value
elif isinstance(value, list):
self.execution_stats.setdefault(key, [])
self.execution_stats[key].extend(value)
def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats:
stats_dict = stats.model_dump()
current_stats = self.execution_stats.model_dump()
for key, value in stats_dict.items():
if key not in current_stats:
# Field doesn't exist yet, just set it, but this will probably
# not happen, just in case though so we throw for invalid when
# converting back in
current_stats[key] = value
elif isinstance(value, dict) and isinstance(current_stats[key], dict):
current_stats[key].update(value)
elif isinstance(value, (int, float)) and isinstance(
current_stats[key], (int, float)
):
current_stats[key] += value
elif isinstance(value, list) and isinstance(current_stats[key], list):
current_stats[key].extend(value)
else:
self.execution_stats[key] = value
current_stats[key] = value
self.execution_stats = NodeExecutionStats(**current_stats)
return self.execution_stats
@property

View File

@@ -32,6 +32,7 @@ from backend.data.model import (
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.user import get_user_by_id
from backend.notifications import NotificationManager
from backend.util.exceptions import InsufficientBalanceError
from backend.util.service import get_service_client
from backend.util.settings import Settings
@@ -313,9 +314,13 @@ class UserCreditBase(ABC):
if amount < 0 and user_balance + amount < 0:
if fail_insufficient_credits:
raise ValueError(
f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}"
raise InsufficientBalanceError(
message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}",
user_id=user_id,
balance=user_balance,
amount=amount,
)
amount = min(-user_balance, 0)
# Create the transaction

View File

@@ -15,6 +15,7 @@ from pydantic import BaseModel
from backend.data.block import BlockData, BlockInput, CompletedBlockOutput
from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.queue import AsyncRedisEventBus, RedisEventBus
from backend.server.v2.store.exceptions import DatabaseError
from backend.util import mock, type
@@ -282,13 +283,16 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu
async def update_graph_execution_stats(
graph_exec_id: str,
status: ExecutionStatus,
stats: dict[str, Any],
stats: GraphExecutionStats,
) -> ExecutionResult:
data = stats.model_dump()
if isinstance(data["error"], Exception):
data["error"] = str(data["error"])
res = await AgentGraphExecution.prisma().update(
where={"id": graph_exec_id},
data={
"executionStatus": status,
"stats": Json(stats),
"stats": Json(data),
},
)
if not res:
@@ -297,10 +301,13 @@ async def update_graph_execution_stats(
return ExecutionResult.from_graph(res)
async def update_node_execution_stats(node_exec_id: str, stats: dict[str, Any]):
async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats):
data = stats.model_dump()
if isinstance(data["error"], Exception):
data["error"] = str(data["error"])
await AgentNodeExecution.prisma().update(
where={"id": node_exec_id},
data={"stats": Json(stats)},
data={"stats": Json(data)},
)

View File

@@ -186,7 +186,9 @@ class GraphExecution(GraphExecutionMeta):
outputs: dict[str, list] = defaultdict(list)
for exec in node_executions:
if exec.block_id == _OUTPUT_BLOCK_ID:
outputs[exec.input_data["name"]].append(exec.input_data["value"])
outputs[exec.input_data["name"]].append(
exec.input_data.get("value", None)
)
return GraphExecution(
**{

View File

@@ -402,3 +402,37 @@ class RefundRequest(BaseModel):
status: str
created_at: datetime
updated_at: datetime
class NodeExecutionStats(BaseModel):
"""Execution statistics for a node execution."""
class Config:
arbitrary_types_allowed = True
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
cost: float = 0
input_size: int = 0
output_size: int = 0
llm_call_count: int = 0
llm_retry_count: int = 0
input_token_count: int = 0
output_token_count: int = 0
class GraphExecutionStats(BaseModel):
"""Execution statistics for a graph execution."""
class Config:
arbitrary_types_allowed = True
error: Optional[Exception | str] = None
walltime: float = 0
cputime: float = 0
nodes_walltime: float = 0
nodes_cputime: float = 0
node_count: int = 0
node_error_count: int = 0
cost: float = 0

View File

@@ -49,10 +49,12 @@ class ZeroBalanceData(BaseNotificationData):
class LowBalanceData(BaseNotificationData):
current_balance: float
threshold_amount: float
top_up_link: str
recent_usage: float = Field(..., description="Usage in the last 24 hours")
agent_name: str = Field(..., description="Name of the agent")
current_balance: float = Field(
..., description="Current balance in credits (100 = $1)"
)
billing_page_link: str = Field(..., description="Link to billing page")
shortfall: float = Field(..., description="Amount of credits needed to continue")
class BlockExecutionFailedData(BaseNotificationData):
@@ -197,7 +199,7 @@ class NotificationTypeOverride:
NotificationType.AGENT_RUN: QueueType.IMMEDIATE,
# These are batched by the notification service, but with a backoff strategy
NotificationType.ZERO_BALANCE: QueueType.BACKOFF,
NotificationType.LOW_BALANCE: QueueType.BACKOFF,
NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
NotificationType.DAILY_SUMMARY: QueueType.DAILY,

View File

@@ -13,11 +13,14 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast
from redis.lock import Lock as RedisLock
from backend.blocks.basic import AgentOutputBlock
from backend.data.model import GraphExecutionStats, NodeExecutionStats
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventDTO,
NotificationType,
)
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManager
@@ -117,7 +120,7 @@ def execute_node(
db_client: "DatabaseManager",
creds_manager: IntegrationCredentialsManager,
data: NodeExecutionEntry,
execution_stats: dict[str, Any] | None = None,
execution_stats: NodeExecutionStats | None = None,
) -> ExecutionStream:
"""
Execute a node in the graph. This will trigger a block execution on a node,
@@ -126,7 +129,6 @@ def execute_node(
Args:
db_client: The client to send execution updates to the server.
creds_manager: The manager to acquire and release credentials.
notification_service: The service to send notifications.
data: The execution data for executing the current node.
execution_stats: The execution statistics to be updated.
@@ -256,10 +258,12 @@ def execute_node(
# Update execution stats
if execution_stats is not None:
execution_stats.update(node_block.execution_stats)
execution_stats["input_size"] = input_size
execution_stats["output_size"] = output_size
execution_stats["cost"] = cost
execution_stats = execution_stats.model_copy(
update=node_block.execution_stats.model_dump()
)
execution_stats.input_size = input_size
execution_stats.output_size = output_size
execution_stats.cost = cost
def _enqueue_next_nodes(
@@ -476,7 +480,6 @@ class Executor:
cls.pid = os.getpid()
cls.db_client = get_db_client()
cls.creds_manager = IntegrationCredentialsManager()
cls.notification_service = get_notification_service()
# Set up shutdown handlers
cls.shutdown_lock = threading.Lock()
@@ -517,7 +520,7 @@ class Executor:
cls,
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
) -> dict[str, Any]:
) -> NodeExecutionStats:
log_metadata = LogMetadata(
user_id=node_exec.user_id,
graph_eid=node_exec.graph_exec_id,
@@ -527,13 +530,15 @@ class Executor:
block_name="-",
)
execution_stats = {}
execution_stats = NodeExecutionStats()
timing_info, _ = cls._on_node_execution(
q, node_exec, log_metadata, execution_stats
)
execution_stats["walltime"] = timing_info.wall_time
execution_stats["cputime"] = timing_info.cpu_time
execution_stats.walltime = timing_info.wall_time
execution_stats.cputime = timing_info.cpu_time
if isinstance(execution_stats.error, Exception):
execution_stats.error = str(execution_stats.error)
cls.db_client.update_node_execution_stats(
node_exec.node_exec_id, execution_stats
)
@@ -546,7 +551,7 @@ class Executor:
q: ExecutionQueue[NodeExecutionEntry],
node_exec: NodeExecutionEntry,
log_metadata: LogMetadata,
stats: dict[str, Any] | None = None,
stats: NodeExecutionStats | None = None,
):
try:
log_metadata.info(f"Start node execution {node_exec.node_exec_id}")
@@ -569,6 +574,9 @@ class Executor:
f"Failed node execution {node_exec.node_exec_id}: {e}"
)
if stats is not None:
stats.error = e
@classmethod
def on_graph_executor_start(cls):
configure_logging()
@@ -577,6 +585,7 @@ class Executor:
cls.db_client = get_db_client()
cls.pool_size = settings.config.num_node_workers
cls.pid = os.getpid()
cls.notification_service = get_notification_service()
cls._init_node_executor_pool()
logger.info(
f"Graph executor {cls.pid} started with {cls.pool_size} node workers"
@@ -618,9 +627,12 @@ class Executor:
timing_info, (exec_stats, status, error) = cls._on_graph_execution(
graph_exec, cancel, log_metadata
)
exec_stats["walltime"] = timing_info.wall_time
exec_stats["cputime"] = timing_info.cpu_time
exec_stats["error"] = str(error) if error else None
exec_stats.walltime = timing_info.wall_time
exec_stats.cputime = timing_info.cpu_time
exec_stats.error = error
if isinstance(exec_stats.error, Exception):
exec_stats.error = str(exec_stats.error)
result = cls.db_client.update_graph_execution_stats(
graph_exec_id=graph_exec.graph_exec_id,
status=status,
@@ -628,48 +640,7 @@ class Executor:
)
cls.db_client.send_execution_update(result)
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
assert metadata is not None
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
# Collect named outputs as a list of dictionaries
named_outputs = []
for output in outputs:
if output.block_id == AgentOutputBlock().id:
# Create a dictionary for this named output
named_output = {
# Include the name as a field in each output
"name": (
output.output_data["name"][0]
if isinstance(output.output_data["name"], list)
else output.output_data["name"]
)
}
# Add all other fields
for key, value in output.output_data.items():
if key != "name":
named_output[key] = value
named_outputs.append(named_output)
event = NotificationEventDTO(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name,
credits_used=exec_stats["cost"],
execution_time=timing_info.wall_time,
graph_id=graph_exec.graph_id,
node_count=exec_stats["node_count"],
).model_dump(),
)
logger.info(f"Sending notification for {event}")
get_notification_service().queue_notification(event)
cls._handle_agent_run_notif(graph_exec, exec_stats)
@classmethod
@time_measured
@@ -678,7 +649,7 @@ class Executor:
graph_exec: GraphExecutionEntry,
cancel: threading.Event,
log_metadata: LogMetadata,
) -> tuple[dict[str, Any], ExecutionStatus, Exception | None]:
) -> tuple[GraphExecutionStats, ExecutionStatus, Exception | None]:
"""
Returns:
dict: The execution statistics of the graph execution.
@@ -686,12 +657,7 @@ class Executor:
Exception | None: The error that occurred during the execution, if any.
"""
log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}")
exec_stats = {
"nodes_walltime": 0,
"nodes_cputime": 0,
"node_count": 0,
"cost": 0,
}
exec_stats = GraphExecutionStats()
error = None
finished = False
@@ -717,18 +683,26 @@ class Executor:
queue.add(node_exec)
running_executions: dict[str, AsyncResult] = {}
low_balance_error: Optional[InsufficientBalanceError] = None
def make_exec_callback(exec_data: NodeExecutionEntry):
node_id = exec_data.node_id
def callback(result: object):
running_executions.pop(node_id)
nonlocal exec_stats
if isinstance(result, dict):
exec_stats["node_count"] += 1
exec_stats["nodes_cputime"] += result.get("cputime", 0)
exec_stats["nodes_walltime"] += result.get("walltime", 0)
exec_stats["cost"] += result.get("cost", 0)
running_executions.pop(exec_data.node_id)
if not isinstance(result, NodeExecutionStats):
return
nonlocal exec_stats, low_balance_error
exec_stats.node_count += 1
exec_stats.nodes_cputime += result.cputime
exec_stats.nodes_walltime += result.walltime
exec_stats.cost += result.cost
if (err := result.error) and isinstance(err, Exception):
exec_stats.node_error_count += 1
if isinstance(err, InsufficientBalanceError):
low_balance_error = err
return callback
@@ -773,6 +747,16 @@ class Executor:
execution.wait(3)
log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}")
if isinstance(low_balance_error, InsufficientBalanceError):
cls._handle_low_balance_notif(
graph_exec.user_id,
graph_exec.graph_id,
exec_stats,
low_balance_error,
)
raise low_balance_error
except Exception as e:
log_metadata.exception(
f"Failed graph execution {graph_exec.graph_exec_id}: {e}"
@@ -791,6 +775,67 @@ class Executor:
error,
)
@classmethod
def _handle_agent_run_notif(
cls,
graph_exec: GraphExecutionEntry,
exec_stats: GraphExecutionStats,
):
metadata = cls.db_client.get_graph_metadata(
graph_exec.graph_id, graph_exec.graph_version
)
outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id)
named_outputs = [
{
key: value[0] if key == "name" else value
for key, value in output.output_data.items()
}
for output in outputs
if output.block_id == AgentOutputBlock().id
]
event = NotificationEventDTO(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
).model_dump(),
)
cls.notification_service.queue_notification(event)
@classmethod
def _handle_low_balance_notif(
cls,
user_id: str,
graph_id: str,
exec_stats: GraphExecutionStats,
e: InsufficientBalanceError,
):
shortfall = e.balance - e.amount
metadata = cls.db_client.get_graph_metadata(graph_id)
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
cls.notification_service.queue_notification(
NotificationEventDTO(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
current_balance=exec_stats.cost,
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
).model_dump(),
)
)
class ExecutionManager(AppService):
def __init__(self):

View File

@@ -0,0 +1,114 @@
{# Low Balance Notification Email Template #}
{# Template variables:
data.agent_name: the name of the agent
data.current_balance: the current balance of the user
data.billing_page_link: the link to the billing page
data.shortfall: the shortfall amount
#}
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Low Balance Warning</strong>
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 165%;
margin-top: 0;
margin-bottom: 20px;
">
Your agent "<strong>{{ data.agent_name }}</strong>" has been stopped due to low balance.
</p>
<div style="
margin-left: 15px;
margin-bottom: 20px;
padding: 15px;
border-left: 4px solid #5D23BB;
background-color: #f8f8ff;
">
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Current Balance:</strong> ${{ "{:.2f}".format((data.current_balance|float)/100) }}
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Shortfall:</strong> ${{ "{:.2f}".format((data.shortfall|float)/100) }}
</p>
</div>
<div style="
margin-left: 15px;
margin-bottom: 20px;
padding: 15px;
border-left: 4px solid #FF6B6B;
background-color: #FFF0F0;
">
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 10px;
">
<strong>Low Balance:</strong>
</p>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
margin-top: 0;
margin-bottom: 5px;
">
Your agent "<strong>{{ data.agent_name }}</strong>" requires additional credits to continue running. The current operation has been canceled until your balance is replenished.
</p>
</div>
<div style="
text-align: center;
margin: 30px 0;
">
<a href="{{ data.billing_page_link }}" style="
font-family: 'Poppins', sans-serif;
background-color: #5D23BB;
color: white;
padding: 12px 24px;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
display: inline-block;
">
Manage Billing
</a>
</div>
<p style="
font-family: 'Poppins', sans-serif;
color: #070629;
font-size: 16px;
line-height: 150%;
margin-top: 30px;
margin-bottom: 10px;
font-style: italic;
">
This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically.
</p>

View File

@@ -4,3 +4,22 @@ class MissingConfigError(Exception):
class NeedConfirmation(Exception):
"""The user must explicitly confirm that they want to proceed"""
class InsufficientBalanceError(ValueError):
user_id: str
message: str
balance: float
amount: float
def __init__(self, message: str, user_id: str, balance: float, amount: float):
super().__init__(message)
self.args = (message, user_id, balance, amount)
self.message = message
self.user_id = user_id
self.balance = balance
self.amount = amount
def __str__(self):
"""Used to display the error message in the frontend, because we str() the error when sending the execution update"""
return self.message

View File

@@ -42,6 +42,7 @@ from Pyro5 import api as pyro
from Pyro5 import config as pyro_config
from backend.data import db, rabbitmq, redis
from backend.util.exceptions import InsufficientBalanceError
from backend.util.json import to_dict
from backend.util.process import AppProcess
from backend.util.retry import conn_retry
@@ -251,6 +252,7 @@ EXCEPTION_MAPPING = {
ValueError,
TimeoutError,
ConnectionError,
InsufficientBalanceError,
]
}
@@ -441,6 +443,7 @@ def fastapi_get_service_client(service_type: Type[AS]) -> AS:
except httpx.HTTPStatusError as e:
logger.error(f"HTTP error in {method_name}: {e.response.text}")
error = RemoteCallError.model_validate(e.response.json(), strict=False)
# DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception
raise EXCEPTION_MAPPING.get(error.type, Exception)(
*(error.args or [str(e)])
)

View File

@@ -21,46 +21,83 @@ class TextFormatter:
self.env.globals.clear()
# Instead of clearing all filters, just remove potentially unsafe ones
unsafe_filters = ["pprint", "urlize", "xmlattr", "tojson"]
unsafe_filters = ["pprint", "tojson", "urlize", "xmlattr"]
for f in unsafe_filters:
if f in self.env.filters:
del self.env.filters[f]
self.env.filters["format"] = format_filter_for_jinja2
# Define allowed CSS properties
# Define allowed CSS properties (sorted alphabetically, if you add more)
allowed_css_properties = [
"font-family",
"background-color",
"border",
"border-bottom",
"border-color",
"border-left",
"border-radius",
"border-right",
"border-style",
"border-top",
"border-width",
"bottom",
"box-shadow",
"clear",
"color",
"display",
"float",
"font-family",
"font-size",
"font-weight",
"height",
"left",
"letter-spacing",
"line-height",
"margin-top",
"margin-bottom",
"margin-left",
"margin-right",
"background-color",
"margin-top",
"padding",
"border-radius",
"font-weight",
"position",
"right",
"text-align",
"text-decoration",
"text-shadow",
"text-transform",
"top",
"width",
]
self.css_sanitizer = CSSSanitizer(allowed_css_properties=allowed_css_properties)
# Define allowed tags (sorted alphabetically, if you add more)
self.allowed_tags = [
"p",
"a",
"b",
"br",
"div",
"em",
"h1",
"h2",
"h3",
"h4",
"h5",
"i",
"img",
"li",
"p",
"span",
"strong",
"u",
"ul",
"li",
"br",
"strong",
"em",
"div",
"span",
]
self.allowed_attributes = {"*": ["style", "class"]}
# Define allowed attributes to be used on specific tags
self.allowed_attributes = {
"*": ["class", "style"],
"a": ["href"],
"img": ["src"],
}
def format_string(self, template_str: str, values=None, **kwargs) -> str:
"""Regular template rendering with escaping"""