diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index ee63d42e75..2ec54b6d6e 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Iterable, List, Literal, NamedTuple, Opti from pydantic import BaseModel, SecretStr +from backend.data.model import NodeExecutionStats from backend.integrations.providers import ProviderName if TYPE_CHECKING: @@ -711,10 +712,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): ) response_text = llm_response.response self.merge_stats( - { - "input_token_count": llm_response.prompt_tokens, - "output_token_count": llm_response.completion_tokens, - } + NodeExecutionStats( + input_token_count=llm_response.prompt_tokens, + output_token_count=llm_response.completion_tokens, + ) ) logger.info(f"LLM attempt-{retry_count} response: {response_text}") @@ -757,10 +758,10 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase): retry_prompt = f"Error calling LLM: {e}" finally: self.merge_stats( - { - "llm_call_count": retry_count + 1, - "llm_retry_count": retry_count, - } + NodeExecutionStats( + llm_call_count=retry_count + 1, + llm_retry_count=retry_count, + ) ) raise RuntimeError(retry_prompt) diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 2509dcd664..941d459edf 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -19,6 +19,7 @@ import jsonschema from prisma.models import AgentBlock from pydantic import BaseModel +from backend.data.model import NodeExecutionStats from backend.util import json from backend.util.settings import Config @@ -316,7 +317,7 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): self.static_output = static_output self.block_type = block_type self.webhook_config = webhook_config - self.execution_stats = {} + self.execution_stats: NodeExecutionStats = NodeExecutionStats() if self.webhook_config: if isinstance(self.webhook_config, BlockWebhookConfig): @@ -394,18 +395,29 @@ class Block(ABC, Generic[BlockSchemaInputType, BlockSchemaOutputType]): return data raise ValueError(f"{self.name} did not produce any output for {output}") - def merge_stats(self, stats: dict[str, Any]) -> dict[str, Any]: - for key, value in stats.items(): - if isinstance(value, dict): - self.execution_stats.setdefault(key, {}).update(value) - elif isinstance(value, (int, float)): - self.execution_stats.setdefault(key, 0) - self.execution_stats[key] += value - elif isinstance(value, list): - self.execution_stats.setdefault(key, []) - self.execution_stats[key].extend(value) + def merge_stats(self, stats: NodeExecutionStats) -> NodeExecutionStats: + stats_dict = stats.model_dump() + current_stats = self.execution_stats.model_dump() + + for key, value in stats_dict.items(): + if key not in current_stats: + # Field doesn't exist yet, just set it, but this will probably + # not happen, just in case though so we throw for invalid when + # converting back in + current_stats[key] = value + elif isinstance(value, dict) and isinstance(current_stats[key], dict): + current_stats[key].update(value) + elif isinstance(value, (int, float)) and isinstance( + current_stats[key], (int, float) + ): + current_stats[key] += value + elif isinstance(value, list) and isinstance(current_stats[key], list): + current_stats[key].extend(value) else: - self.execution_stats[key] = value + current_stats[key] = value + + self.execution_stats = NodeExecutionStats(**current_stats) + return self.execution_stats @property diff --git a/autogpt_platform/backend/backend/data/credit.py b/autogpt_platform/backend/backend/data/credit.py index db422531c4..452b60d750 100644 --- a/autogpt_platform/backend/backend/data/credit.py +++ b/autogpt_platform/backend/backend/data/credit.py @@ -32,6 +32,7 @@ from backend.data.model import ( from backend.data.notifications import NotificationEventDTO, RefundRequestData from backend.data.user import get_user_by_id from backend.notifications import NotificationManager +from backend.util.exceptions import InsufficientBalanceError from backend.util.service import get_service_client from backend.util.settings import Settings @@ -313,9 +314,13 @@ class UserCreditBase(ABC): if amount < 0 and user_balance + amount < 0: if fail_insufficient_credits: - raise ValueError( - f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}" + raise InsufficientBalanceError( + message=f"Insufficient balance of ${user_balance/100}, where this will cost ${abs(amount)/100}", + user_id=user_id, + balance=user_balance, + amount=amount, ) + amount = min(-user_balance, 0) # Create the transaction diff --git a/autogpt_platform/backend/backend/data/execution.py b/autogpt_platform/backend/backend/data/execution.py index 57dba68ab7..ca6926ee1c 100644 --- a/autogpt_platform/backend/backend/data/execution.py +++ b/autogpt_platform/backend/backend/data/execution.py @@ -15,6 +15,7 @@ from pydantic import BaseModel from backend.data.block import BlockData, BlockInput, CompletedBlockOutput from backend.data.includes import EXECUTION_RESULT_INCLUDE, GRAPH_EXECUTION_INCLUDE +from backend.data.model import GraphExecutionStats, NodeExecutionStats from backend.data.queue import AsyncRedisEventBus, RedisEventBus from backend.server.v2.store.exceptions import DatabaseError from backend.util import mock, type @@ -282,13 +283,16 @@ async def update_graph_execution_start_time(graph_exec_id: str) -> ExecutionResu async def update_graph_execution_stats( graph_exec_id: str, status: ExecutionStatus, - stats: dict[str, Any], + stats: GraphExecutionStats, ) -> ExecutionResult: + data = stats.model_dump() + if isinstance(data["error"], Exception): + data["error"] = str(data["error"]) res = await AgentGraphExecution.prisma().update( where={"id": graph_exec_id}, data={ "executionStatus": status, - "stats": Json(stats), + "stats": Json(data), }, ) if not res: @@ -297,10 +301,13 @@ async def update_graph_execution_stats( return ExecutionResult.from_graph(res) -async def update_node_execution_stats(node_exec_id: str, stats: dict[str, Any]): +async def update_node_execution_stats(node_exec_id: str, stats: NodeExecutionStats): + data = stats.model_dump() + if isinstance(data["error"], Exception): + data["error"] = str(data["error"]) await AgentNodeExecution.prisma().update( where={"id": node_exec_id}, - data={"stats": Json(stats)}, + data={"stats": Json(data)}, ) diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 18e0cb7ee6..39fdd5b6ed 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -186,7 +186,9 @@ class GraphExecution(GraphExecutionMeta): outputs: dict[str, list] = defaultdict(list) for exec in node_executions: if exec.block_id == _OUTPUT_BLOCK_ID: - outputs[exec.input_data["name"]].append(exec.input_data["value"]) + outputs[exec.input_data["name"]].append( + exec.input_data.get("value", None) + ) return GraphExecution( **{ diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index b52b13b2f7..b1f3a74dc0 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -402,3 +402,37 @@ class RefundRequest(BaseModel): status: str created_at: datetime updated_at: datetime + + +class NodeExecutionStats(BaseModel): + """Execution statistics for a node execution.""" + + class Config: + arbitrary_types_allowed = True + + error: Optional[Exception | str] = None + walltime: float = 0 + cputime: float = 0 + cost: float = 0 + input_size: int = 0 + output_size: int = 0 + llm_call_count: int = 0 + llm_retry_count: int = 0 + input_token_count: int = 0 + output_token_count: int = 0 + + +class GraphExecutionStats(BaseModel): + """Execution statistics for a graph execution.""" + + class Config: + arbitrary_types_allowed = True + + error: Optional[Exception | str] = None + walltime: float = 0 + cputime: float = 0 + nodes_walltime: float = 0 + nodes_cputime: float = 0 + node_count: int = 0 + node_error_count: int = 0 + cost: float = 0 diff --git a/autogpt_platform/backend/backend/data/notifications.py b/autogpt_platform/backend/backend/data/notifications.py index db519e103f..43134f0d3a 100644 --- a/autogpt_platform/backend/backend/data/notifications.py +++ b/autogpt_platform/backend/backend/data/notifications.py @@ -49,10 +49,12 @@ class ZeroBalanceData(BaseNotificationData): class LowBalanceData(BaseNotificationData): - current_balance: float - threshold_amount: float - top_up_link: str - recent_usage: float = Field(..., description="Usage in the last 24 hours") + agent_name: str = Field(..., description="Name of the agent") + current_balance: float = Field( + ..., description="Current balance in credits (100 = $1)" + ) + billing_page_link: str = Field(..., description="Link to billing page") + shortfall: float = Field(..., description="Amount of credits needed to continue") class BlockExecutionFailedData(BaseNotificationData): @@ -197,7 +199,7 @@ class NotificationTypeOverride: NotificationType.AGENT_RUN: QueueType.IMMEDIATE, # These are batched by the notification service, but with a backoff strategy NotificationType.ZERO_BALANCE: QueueType.BACKOFF, - NotificationType.LOW_BALANCE: QueueType.BACKOFF, + NotificationType.LOW_BALANCE: QueueType.IMMEDIATE, NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF, NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF, NotificationType.DAILY_SUMMARY: QueueType.DAILY, diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 4b92903a9b..2a3100166f 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -13,11 +13,14 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, cast from redis.lock import Lock as RedisLock from backend.blocks.basic import AgentOutputBlock +from backend.data.model import GraphExecutionStats, NodeExecutionStats from backend.data.notifications import ( AgentRunData, + LowBalanceData, NotificationEventDTO, NotificationType, ) +from backend.util.exceptions import InsufficientBalanceError if TYPE_CHECKING: from backend.executor import DatabaseManager @@ -117,7 +120,7 @@ def execute_node( db_client: "DatabaseManager", creds_manager: IntegrationCredentialsManager, data: NodeExecutionEntry, - execution_stats: dict[str, Any] | None = None, + execution_stats: NodeExecutionStats | None = None, ) -> ExecutionStream: """ Execute a node in the graph. This will trigger a block execution on a node, @@ -126,7 +129,6 @@ def execute_node( Args: db_client: The client to send execution updates to the server. creds_manager: The manager to acquire and release credentials. - notification_service: The service to send notifications. data: The execution data for executing the current node. execution_stats: The execution statistics to be updated. @@ -256,10 +258,12 @@ def execute_node( # Update execution stats if execution_stats is not None: - execution_stats.update(node_block.execution_stats) - execution_stats["input_size"] = input_size - execution_stats["output_size"] = output_size - execution_stats["cost"] = cost + execution_stats = execution_stats.model_copy( + update=node_block.execution_stats.model_dump() + ) + execution_stats.input_size = input_size + execution_stats.output_size = output_size + execution_stats.cost = cost def _enqueue_next_nodes( @@ -476,7 +480,6 @@ class Executor: cls.pid = os.getpid() cls.db_client = get_db_client() cls.creds_manager = IntegrationCredentialsManager() - cls.notification_service = get_notification_service() # Set up shutdown handlers cls.shutdown_lock = threading.Lock() @@ -517,7 +520,7 @@ class Executor: cls, q: ExecutionQueue[NodeExecutionEntry], node_exec: NodeExecutionEntry, - ) -> dict[str, Any]: + ) -> NodeExecutionStats: log_metadata = LogMetadata( user_id=node_exec.user_id, graph_eid=node_exec.graph_exec_id, @@ -527,13 +530,15 @@ class Executor: block_name="-", ) - execution_stats = {} + execution_stats = NodeExecutionStats() timing_info, _ = cls._on_node_execution( q, node_exec, log_metadata, execution_stats ) - execution_stats["walltime"] = timing_info.wall_time - execution_stats["cputime"] = timing_info.cpu_time + execution_stats.walltime = timing_info.wall_time + execution_stats.cputime = timing_info.cpu_time + if isinstance(execution_stats.error, Exception): + execution_stats.error = str(execution_stats.error) cls.db_client.update_node_execution_stats( node_exec.node_exec_id, execution_stats ) @@ -546,7 +551,7 @@ class Executor: q: ExecutionQueue[NodeExecutionEntry], node_exec: NodeExecutionEntry, log_metadata: LogMetadata, - stats: dict[str, Any] | None = None, + stats: NodeExecutionStats | None = None, ): try: log_metadata.info(f"Start node execution {node_exec.node_exec_id}") @@ -569,6 +574,9 @@ class Executor: f"Failed node execution {node_exec.node_exec_id}: {e}" ) + if stats is not None: + stats.error = e + @classmethod def on_graph_executor_start(cls): configure_logging() @@ -577,6 +585,7 @@ class Executor: cls.db_client = get_db_client() cls.pool_size = settings.config.num_node_workers cls.pid = os.getpid() + cls.notification_service = get_notification_service() cls._init_node_executor_pool() logger.info( f"Graph executor {cls.pid} started with {cls.pool_size} node workers" @@ -618,9 +627,12 @@ class Executor: timing_info, (exec_stats, status, error) = cls._on_graph_execution( graph_exec, cancel, log_metadata ) - exec_stats["walltime"] = timing_info.wall_time - exec_stats["cputime"] = timing_info.cpu_time - exec_stats["error"] = str(error) if error else None + exec_stats.walltime = timing_info.wall_time + exec_stats.cputime = timing_info.cpu_time + exec_stats.error = error + + if isinstance(exec_stats.error, Exception): + exec_stats.error = str(exec_stats.error) result = cls.db_client.update_graph_execution_stats( graph_exec_id=graph_exec.graph_exec_id, status=status, @@ -628,48 +640,7 @@ class Executor: ) cls.db_client.send_execution_update(result) - metadata = cls.db_client.get_graph_metadata( - graph_exec.graph_id, graph_exec.graph_version - ) - assert metadata is not None - outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id) - - # Collect named outputs as a list of dictionaries - named_outputs = [] - for output in outputs: - if output.block_id == AgentOutputBlock().id: - # Create a dictionary for this named output - named_output = { - # Include the name as a field in each output - "name": ( - output.output_data["name"][0] - if isinstance(output.output_data["name"], list) - else output.output_data["name"] - ) - } - - # Add all other fields - for key, value in output.output_data.items(): - if key != "name": - named_output[key] = value - - named_outputs.append(named_output) - - event = NotificationEventDTO( - user_id=graph_exec.user_id, - type=NotificationType.AGENT_RUN, - data=AgentRunData( - outputs=named_outputs, - agent_name=metadata.name, - credits_used=exec_stats["cost"], - execution_time=timing_info.wall_time, - graph_id=graph_exec.graph_id, - node_count=exec_stats["node_count"], - ).model_dump(), - ) - - logger.info(f"Sending notification for {event}") - get_notification_service().queue_notification(event) + cls._handle_agent_run_notif(graph_exec, exec_stats) @classmethod @time_measured @@ -678,7 +649,7 @@ class Executor: graph_exec: GraphExecutionEntry, cancel: threading.Event, log_metadata: LogMetadata, - ) -> tuple[dict[str, Any], ExecutionStatus, Exception | None]: + ) -> tuple[GraphExecutionStats, ExecutionStatus, Exception | None]: """ Returns: dict: The execution statistics of the graph execution. @@ -686,12 +657,7 @@ class Executor: Exception | None: The error that occurred during the execution, if any. """ log_metadata.info(f"Start graph execution {graph_exec.graph_exec_id}") - exec_stats = { - "nodes_walltime": 0, - "nodes_cputime": 0, - "node_count": 0, - "cost": 0, - } + exec_stats = GraphExecutionStats() error = None finished = False @@ -717,18 +683,26 @@ class Executor: queue.add(node_exec) running_executions: dict[str, AsyncResult] = {} + low_balance_error: Optional[InsufficientBalanceError] = None def make_exec_callback(exec_data: NodeExecutionEntry): - node_id = exec_data.node_id def callback(result: object): - running_executions.pop(node_id) - nonlocal exec_stats - if isinstance(result, dict): - exec_stats["node_count"] += 1 - exec_stats["nodes_cputime"] += result.get("cputime", 0) - exec_stats["nodes_walltime"] += result.get("walltime", 0) - exec_stats["cost"] += result.get("cost", 0) + running_executions.pop(exec_data.node_id) + + if not isinstance(result, NodeExecutionStats): + return + + nonlocal exec_stats, low_balance_error + exec_stats.node_count += 1 + exec_stats.nodes_cputime += result.cputime + exec_stats.nodes_walltime += result.walltime + exec_stats.cost += result.cost + if (err := result.error) and isinstance(err, Exception): + exec_stats.node_error_count += 1 + + if isinstance(err, InsufficientBalanceError): + low_balance_error = err return callback @@ -773,6 +747,16 @@ class Executor: execution.wait(3) log_metadata.info(f"Finished graph execution {graph_exec.graph_exec_id}") + + if isinstance(low_balance_error, InsufficientBalanceError): + cls._handle_low_balance_notif( + graph_exec.user_id, + graph_exec.graph_id, + exec_stats, + low_balance_error, + ) + raise low_balance_error + except Exception as e: log_metadata.exception( f"Failed graph execution {graph_exec.graph_exec_id}: {e}" @@ -791,6 +775,67 @@ class Executor: error, ) + @classmethod + def _handle_agent_run_notif( + cls, + graph_exec: GraphExecutionEntry, + exec_stats: GraphExecutionStats, + ): + metadata = cls.db_client.get_graph_metadata( + graph_exec.graph_id, graph_exec.graph_version + ) + outputs = cls.db_client.get_execution_results(graph_exec.graph_exec_id) + + named_outputs = [ + { + key: value[0] if key == "name" else value + for key, value in output.output_data.items() + } + for output in outputs + if output.block_id == AgentOutputBlock().id + ] + + event = NotificationEventDTO( + user_id=graph_exec.user_id, + type=NotificationType.AGENT_RUN, + data=AgentRunData( + outputs=named_outputs, + agent_name=metadata.name if metadata else "Unknown Agent", + credits_used=exec_stats.cost, + execution_time=exec_stats.walltime, + graph_id=graph_exec.graph_id, + node_count=exec_stats.node_count, + ).model_dump(), + ) + + cls.notification_service.queue_notification(event) + + @classmethod + def _handle_low_balance_notif( + cls, + user_id: str, + graph_id: str, + exec_stats: GraphExecutionStats, + e: InsufficientBalanceError, + ): + shortfall = e.balance - e.amount + metadata = cls.db_client.get_graph_metadata(graph_id) + base_url = ( + settings.config.frontend_base_url or settings.config.platform_base_url + ) + cls.notification_service.queue_notification( + NotificationEventDTO( + user_id=user_id, + type=NotificationType.LOW_BALANCE, + data=LowBalanceData( + current_balance=exec_stats.cost, + billing_page_link=f"{base_url}/profile/credits", + shortfall=shortfall, + agent_name=metadata.name if metadata else "Unknown Agent", + ).model_dump(), + ) + ) + class ExecutionManager(AppService): def __init__(self): diff --git a/autogpt_platform/backend/backend/notifications/templates/low_balance.html.jinja2 b/autogpt_platform/backend/backend/notifications/templates/low_balance.html.jinja2 new file mode 100644 index 0000000000..d39d17cb14 --- /dev/null +++ b/autogpt_platform/backend/backend/notifications/templates/low_balance.html.jinja2 @@ -0,0 +1,114 @@ +{# Low Balance Notification Email Template #} +{# Template variables: +data.agent_name: the name of the agent +data.current_balance: the current balance of the user +data.billing_page_link: the link to the billing page +data.shortfall: the shortfall amount +#} + +

+ Low Balance Warning +

+ +

+ Your agent "{{ data.agent_name }}" has been stopped due to low balance. +

+ +
+

+ Current Balance: ${{ "{:.2f}".format((data.current_balance|float)/100) }} +

+

+ Shortfall: ${{ "{:.2f}".format((data.shortfall|float)/100) }} +

+
+ + +
+

+ Low Balance: +

+

+ Your agent "{{ data.agent_name }}" requires additional credits to continue running. The current operation has been canceled until your balance is replenished. +

+
+ +
+ + Manage Billing + +
+ +

+ This is an automated notification. Your agent is stopped and will need manually restarted unless set to trigger automatically. +

diff --git a/autogpt_platform/backend/backend/util/exceptions.py b/autogpt_platform/backend/backend/util/exceptions.py index 4bb3a08d95..eb6e42bda7 100644 --- a/autogpt_platform/backend/backend/util/exceptions.py +++ b/autogpt_platform/backend/backend/util/exceptions.py @@ -4,3 +4,22 @@ class MissingConfigError(Exception): class NeedConfirmation(Exception): """The user must explicitly confirm that they want to proceed""" + + +class InsufficientBalanceError(ValueError): + user_id: str + message: str + balance: float + amount: float + + def __init__(self, message: str, user_id: str, balance: float, amount: float): + super().__init__(message) + self.args = (message, user_id, balance, amount) + self.message = message + self.user_id = user_id + self.balance = balance + self.amount = amount + + def __str__(self): + """Used to display the error message in the frontend, because we str() the error when sending the execution update""" + return self.message diff --git a/autogpt_platform/backend/backend/util/service.py b/autogpt_platform/backend/backend/util/service.py index e29a9a5d19..716d1b13f2 100644 --- a/autogpt_platform/backend/backend/util/service.py +++ b/autogpt_platform/backend/backend/util/service.py @@ -42,6 +42,7 @@ from Pyro5 import api as pyro from Pyro5 import config as pyro_config from backend.data import db, rabbitmq, redis +from backend.util.exceptions import InsufficientBalanceError from backend.util.json import to_dict from backend.util.process import AppProcess from backend.util.retry import conn_retry @@ -251,6 +252,7 @@ EXCEPTION_MAPPING = { ValueError, TimeoutError, ConnectionError, + InsufficientBalanceError, ] } @@ -441,6 +443,7 @@ def fastapi_get_service_client(service_type: Type[AS]) -> AS: except httpx.HTTPStatusError as e: logger.error(f"HTTP error in {method_name}: {e.response.text}") error = RemoteCallError.model_validate(e.response.json(), strict=False) + # DEBUG HELP: if you made a custom exception, make sure you override self.args to be how to make your exception raise EXCEPTION_MAPPING.get(error.type, Exception)( *(error.args or [str(e)]) ) diff --git a/autogpt_platform/backend/backend/util/text.py b/autogpt_platform/backend/backend/util/text.py index e118325f7e..0c1e22edc0 100644 --- a/autogpt_platform/backend/backend/util/text.py +++ b/autogpt_platform/backend/backend/util/text.py @@ -21,46 +21,83 @@ class TextFormatter: self.env.globals.clear() # Instead of clearing all filters, just remove potentially unsafe ones - unsafe_filters = ["pprint", "urlize", "xmlattr", "tojson"] + unsafe_filters = ["pprint", "tojson", "urlize", "xmlattr"] for f in unsafe_filters: if f in self.env.filters: del self.env.filters[f] self.env.filters["format"] = format_filter_for_jinja2 - # Define allowed CSS properties + # Define allowed CSS properties (sorted alphabetically, if you add more) allowed_css_properties = [ - "font-family", + "background-color", + "border", + "border-bottom", + "border-color", + "border-left", + "border-radius", + "border-right", + "border-style", + "border-top", + "border-width", + "bottom", + "box-shadow", + "clear", "color", + "display", + "float", + "font-family", "font-size", + "font-weight", + "height", + "left", + "letter-spacing", "line-height", - "margin-top", "margin-bottom", "margin-left", "margin-right", - "background-color", + "margin-top", "padding", - "border-radius", - "font-weight", + "position", + "right", "text-align", + "text-decoration", + "text-shadow", + "text-transform", + "top", + "width", ] self.css_sanitizer = CSSSanitizer(allowed_css_properties=allowed_css_properties) + # Define allowed tags (sorted alphabetically, if you add more) self.allowed_tags = [ - "p", + "a", "b", + "br", + "div", + "em", + "h1", + "h2", + "h3", + "h4", + "h5", "i", + "img", + "li", + "p", + "span", + "strong", "u", "ul", - "li", - "br", - "strong", - "em", - "div", - "span", ] - self.allowed_attributes = {"*": ["style", "class"]} + + # Define allowed attributes to be used on specific tags + self.allowed_attributes = { + "*": ["class", "style"], + "a": ["href"], + "img": ["src"], + } def format_string(self, template_str: str, values=None, **kwargs) -> str: """Regular template rendering with escaping"""