refactor(backend): Clear out Notification Service code blockage (#9915)

Some code paths in the notification and scheduler services were
synchronous HTTP calls that executed long-running, blocking jobs. This
kept the service threads busy waiting.

### Changes 🏗️

* Remove queue_notification API
* Remove DTO
* Move heavy tasks into the executor

<!-- Concisely describe all of the changes made in this pull request:
-->

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Manually executing notification service jobs through the scheduler
API
This commit is contained in:
Zamil Majdy
2025-05-09 13:16:43 +07:00
committed by GitHub
parent 089e7aae88
commit 82cf0bcde7
4 changed files with 160 additions and 172 deletions

View File

@@ -5,7 +5,6 @@ from datetime import datetime, timezone
from typing import Any, cast
import stripe
from autogpt_libs.utils.cache import thread_cached
from prisma import Json
from prisma.enums import (
CreditRefundRequestStatus,
@@ -32,14 +31,13 @@ from backend.data.model import (
TransactionHistory,
UserTransaction,
)
from backend.data.notifications import NotificationEventDTO, RefundRequestData
from backend.data.notifications import NotificationEventModel, RefundRequestData
from backend.data.user import get_user_by_id, get_user_email_by_id
from backend.notifications import NotificationManagerClient
from backend.notifications.notifications import queue_notification_async
from backend.server.model import Pagination
from backend.server.v2.admin.model import UserHistoryResponse
from backend.util.exceptions import InsufficientBalanceError
from backend.util.retry import func_retry
from backend.util.service import get_service_client
from backend.util.settings import Settings
settings = Settings()
@@ -374,20 +372,17 @@ class UserCreditBase(ABC):
class UserCredit(UserCreditBase):
@thread_cached
def notification_client(self) -> NotificationManagerClient:
return get_service_client(NotificationManagerClient)
async def _send_refund_notification(
self,
notification_request: RefundRequestData,
notification_type: NotificationType,
):
await self.notification_client().queue_notification_async(
NotificationEventDTO(
await queue_notification_async(
NotificationEventModel(
user_id=notification_request.user_id,
type=notification_type,
data=notification_request.model_dump(),
data=notification_request,
)
)

View File

@@ -189,26 +189,14 @@ NotificationData = Annotated[
]
class NotificationEventDTO(BaseModel):
user_id: str
class BaseEventModel(BaseModel):
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
retry_count: int = 0
class SummaryParamsEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
user_id: str
type: NotificationType
class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
data: NotificationDataType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
@property
def strategy(self) -> QueueType:
@@ -225,11 +213,8 @@ class NotificationEventModel(BaseModel, Generic[NotificationDataType_co]):
return NotificationTypeOverride(self.type).template
class SummaryParamsEventModel(BaseModel, Generic[SummaryParamsType_co]):
user_id: str
type: NotificationType
class SummaryParamsEventModel(BaseEventModel, Generic[SummaryParamsType_co]):
data: SummaryParamsType_co
created_at: datetime = Field(default_factory=lambda: datetime.now(tz=timezone.utc))
def get_notif_data_type(

View File

@@ -23,16 +23,16 @@ from backend.data.model import (
from backend.data.notifications import (
AgentRunData,
LowBalanceData,
NotificationEventDTO,
NotificationEventModel,
NotificationType,
)
from backend.data.rabbitmq import SyncRabbitMQ
from backend.executor.utils import create_execution_queue_config
from backend.notifications.notifications import queue_notification
from backend.util.exceptions import InsufficientBalanceError
if TYPE_CHECKING:
from backend.executor import DatabaseManagerClient
from backend.notifications.notifications import NotificationManagerClient
from autogpt_libs.utils.cache import thread_cached
from prometheus_client import Gauge, start_http_server
@@ -580,7 +580,6 @@ class Executor:
cls.db_client = get_db_client()
cls.pool_size = settings.config.num_node_workers
cls.pid = os.getpid()
cls.notification_service = get_notification_service()
cls._init_node_executor_pool()
logger.info(f"GraphExec {cls.pid} started with {cls.pool_size} node workers")
@@ -905,21 +904,21 @@ class Executor:
for output in outputs
]
event = NotificationEventDTO(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
).model_dump(),
queue_notification(
NotificationEventModel(
user_id=graph_exec.user_id,
type=NotificationType.AGENT_RUN,
data=AgentRunData(
outputs=named_outputs,
agent_name=metadata.name if metadata else "Unknown Agent",
credits_used=exec_stats.cost,
execution_time=exec_stats.walltime,
graph_id=graph_exec.graph_id,
node_count=exec_stats.node_count,
),
)
)
cls.notification_service.queue_notification(event)
@classmethod
def _handle_low_balance_notif(
cls,
@@ -933,8 +932,8 @@ class Executor:
base_url = (
settings.config.frontend_base_url or settings.config.platform_base_url
)
cls.notification_service.queue_notification(
NotificationEventDTO(
queue_notification(
NotificationEventModel(
user_id=user_id,
type=NotificationType.LOW_BALANCE,
data=LowBalanceData(
@@ -942,7 +941,7 @@ class Executor:
billing_page_link=f"{base_url}/profile/credits",
shortfall=shortfall,
agent_name=metadata.name if metadata else "Unknown Agent",
).model_dump(),
),
)
)
@@ -1139,14 +1138,6 @@ def get_db_client() -> "DatabaseManagerClient":
return get_service_client(DatabaseManagerClient, health_check=False)
@thread_cached
def get_notification_service() -> "NotificationManagerClient":
from backend.notifications import NotificationManagerClient
# Disable health check for the service client to avoid breaking process initializer.
return get_service_client(NotificationManagerClient, health_check=False)
def send_execution_update(entry: GraphExecution | NodeExecutionResult | None):
if entry is None:
return

View File

@@ -1,5 +1,6 @@
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta, timezone
from typing import Callable
@@ -7,20 +8,18 @@ import aio_pika
from aio_pika.exceptions import QueueEmpty
from autogpt_libs.utils.cache import thread_cached
from prisma.enums import NotificationType
from pydantic import BaseModel
from backend.data import rabbitmq
from backend.data.notifications import (
BaseEventModel,
BaseSummaryData,
BaseSummaryParams,
DailySummaryData,
DailySummaryParams,
NotificationEventDTO,
NotificationEventModel,
NotificationResult,
NotificationTypeOverride,
QueueType,
SummaryParamsEventDTO,
SummaryParamsEventModel,
WeeklySummaryData,
WeeklySummaryParams,
@@ -28,13 +27,19 @@ from backend.data.notifications import (
get_notif_data_type,
get_summary_params_type,
)
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
SyncRabbitMQ,
)
from backend.data.user import generate_unsubscribe_link
from backend.notifications.email import EmailSender
from backend.util.service import (
AppService,
AppServiceClient,
endpoint_to_async,
expose,
get_service_client,
)
@@ -44,70 +49,66 @@ logger = logging.getLogger(__name__)
settings = Settings()
class NotificationEvent(BaseModel):
event: NotificationEventDTO
model: NotificationEventModel
NOTIFICATION_EXCHANGE = Exchange(name="notifications", type=ExchangeType.TOPIC)
DEAD_LETTER_EXCHANGE = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
EXCHANGES = [NOTIFICATION_EXCHANGE, DEAD_LETTER_EXCHANGE]
background_executor = ThreadPoolExecutor(max_workers=2)
def create_notification_config() -> RabbitMQConfig:
"""Create RabbitMQ configuration for notifications"""
notification_exchange = Exchange(name="notifications", type=ExchangeType.TOPIC)
dead_letter_exchange = Exchange(name="dead_letter", type=ExchangeType.TOPIC)
queues = [
# Main notification queues
Queue(
name="immediate_notifications",
exchange=notification_exchange,
exchange=NOTIFICATION_EXCHANGE,
routing_key="notification.immediate.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
"x-dead-letter-routing-key": "failed.immediate",
},
),
Queue(
name="admin_notifications",
exchange=notification_exchange,
exchange=NOTIFICATION_EXCHANGE,
routing_key="notification.admin.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
"x-dead-letter-routing-key": "failed.admin",
},
),
# Summary notification queues
Queue(
name="summary_notifications",
exchange=notification_exchange,
exchange=NOTIFICATION_EXCHANGE,
routing_key="notification.summary.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
"x-dead-letter-routing-key": "failed.summary",
},
),
# Batch Queue
Queue(
name="batch_notifications",
exchange=notification_exchange,
exchange=NOTIFICATION_EXCHANGE,
routing_key="notification.batch.#",
arguments={
"x-dead-letter-exchange": dead_letter_exchange.name,
"x-dead-letter-exchange": DEAD_LETTER_EXCHANGE.name,
"x-dead-letter-routing-key": "failed.batch",
},
),
# Failed notifications queue
Queue(
name="failed_notifications",
exchange=dead_letter_exchange,
exchange=DEAD_LETTER_EXCHANGE,
routing_key="failed.#",
),
]
return RabbitMQConfig(
exchanges=[
notification_exchange,
dead_letter_exchange,
],
exchanges=EXCHANGES,
queues=queues,
)
@@ -119,6 +120,86 @@ def get_db():
return get_service_client(DatabaseManagerClient)
@thread_cached
def get_notification_queue() -> SyncRabbitMQ:
    """Return a connected synchronous RabbitMQ client for notifications.

    The client is built from ``create_notification_config()`` and connected
    before being handed back.
    """
    # NOTE(review): ``thread_cached`` presumably memoises the connected client
    # per thread, so the connection is opened once per thread -- confirm
    # against the decorator's implementation.
    rabbit = SyncRabbitMQ(create_notification_config())
    rabbit.connect()
    return rabbit
@thread_cached
async def get_async_notification_queue() -> AsyncRabbitMQ:
    """Return a connected asynchronous RabbitMQ client for notifications.

    The client is built from ``create_notification_config()`` and its
    connection is awaited before it is returned.
    """
    # NOTE(review): this relies on ``thread_cached`` handling coroutine
    # functions (caching the awaited client, not the coroutine object) --
    # confirm against the decorator's implementation.
    rabbit = AsyncRabbitMQ(create_notification_config())
    await rabbit.connect()
    return rabbit
def get_routing_key(event_type: NotificationType) -> str:
    """Get the appropriate routing key for an event.

    The delivery strategy is looked up through
    ``NotificationTypeOverride(event_type).strategy`` and mapped to a key of
    the form ``notification.<strategy>.<event type value>``.

    Args:
        event_type: The notification type being queued.

    Returns:
        The routing key for the matched strategy, or
        ``notification.<event type value>`` when the strategy matches none of
        the cases handled below.
    """
    strategy = NotificationTypeOverride(event_type).strategy
    if strategy == QueueType.IMMEDIATE:
        return f"notification.immediate.{event_type.value}"
    elif strategy == QueueType.BACKOFF:
        # NOTE(review): no "notification.backoff.#" binding was spotted among
        # the queues declared for the notification exchange -- confirm that
        # messages published with this key are actually routed somewhere.
        return f"notification.backoff.{event_type.value}"
    elif strategy == QueueType.ADMIN:
        return f"notification.admin.{event_type.value}"
    elif strategy == QueueType.BATCH:
        return f"notification.batch.{event_type.value}"
    elif strategy == QueueType.SUMMARY:
        return f"notification.summary.{event_type.value}"
    # Fallback for any strategy without a dedicated routing segment.
    return f"notification.{event_type.value}"
def queue_notification(event: NotificationEventModel) -> NotificationResult:
    """Publish a notification event to RabbitMQ using the synchronous client.

    The event is serialised with ``model_dump_json()`` and published to the
    exchange named ``"notifications"`` under the routing key derived from
    ``event.type``.

    Args:
        event: The notification event to publish.

    Returns:
        A ``NotificationResult`` with ``success=True`` and the routing key in
        the message on success. Any ``Exception`` raised along the way is
        logged and reported as ``success=False`` with the error text instead
        of being propagated.
    """
    try:
        logger.debug(f"Received Request to queue {event=}")

        exchange_name = "notifications"
        routing_key = get_routing_key(event.type)
        rabbit = get_notification_queue()

        # Serialise first, then resolve the exchange: same order in which the
        # publish arguments were originally evaluated.
        payload = event.model_dump_json()
        target_exchange = next(ex for ex in EXCHANGES if ex.name == exchange_name)
        rabbit.publish_message(
            routing_key=routing_key,
            message=payload,
            exchange=target_exchange,
        )
    except Exception as exc:
        logger.exception(f"Error queueing notification: {exc}")
        return NotificationResult(success=False, message=str(exc))
    else:
        return NotificationResult(
            success=True,
            message=f"Notification queued with routing key: {routing_key}",
        )
async def queue_notification_async(event: NotificationEventModel) -> NotificationResult:
    """Publish a notification event to RabbitMQ using the asynchronous client.

    The event is serialised with ``model_dump_json()`` and published to the
    exchange named ``"notifications"`` under the routing key derived from
    ``event.type``.

    Args:
        event: The notification event to publish.

    Returns:
        A ``NotificationResult`` with ``success=True`` and the routing key in
        the message on success. Any ``Exception`` raised along the way is
        logged and reported as ``success=False`` with the error text instead
        of being propagated.
    """
    try:
        logger.debug(f"Received Request to queue {event=}")

        exchange_name = "notifications"
        routing_key = get_routing_key(event.type)
        rabbit = await get_async_notification_queue()

        # Serialise first, then resolve the exchange: same order in which the
        # publish arguments were originally evaluated.
        payload = event.model_dump_json()
        target_exchange = next(ex for ex in EXCHANGES if ex.name == exchange_name)
        await rabbit.publish_message(
            routing_key=routing_key,
            message=payload,
            exchange=target_exchange,
        )
    except Exception as exc:
        logger.exception(f"Error queueing notification: {exc}")
        return NotificationResult(success=False, message=str(exc))
    else:
        return NotificationResult(
            success=True,
            message=f"Notification queued with routing key: {routing_key}",
        )
class NotificationManager(AppService):
"""Service for handling notifications with batching support"""
@@ -146,23 +227,11 @@ class NotificationManager(AppService):
def get_port(cls) -> int:
return settings.config.notification_service_port
def get_routing_key(self, event_type: NotificationType) -> str:
strategy = NotificationTypeOverride(event_type).strategy
"""Get the appropriate routing key for an event"""
if strategy == QueueType.IMMEDIATE:
return f"notification.immediate.{event_type.value}"
elif strategy == QueueType.BACKOFF:
return f"notification.backoff.{event_type.value}"
elif strategy == QueueType.ADMIN:
return f"notification.admin.{event_type.value}"
elif strategy == QueueType.BATCH:
return f"notification.batch.{event_type.value}"
elif strategy == QueueType.SUMMARY:
return f"notification.summary.{event_type.value}"
return f"notification.{event_type.value}"
@expose
def queue_weekly_summary(self):
background_executor.submit(self._queue_weekly_summary)
def _queue_weekly_summary(self):
"""Process weekly summary for specified notification types"""
try:
logger.info("Processing weekly summary queuing operation")
@@ -176,13 +245,13 @@ class NotificationManager(AppService):
for user in users:
self._queue_scheduled_notification(
SummaryParamsEventDTO(
SummaryParamsEventModel(
user_id=user,
type=NotificationType.WEEKLY_SUMMARY,
data=WeeklySummaryParams(
start_date=start_time,
end_date=current_time,
).model_dump(),
),
),
)
processed_count += 1
@@ -194,6 +263,9 @@ class NotificationManager(AppService):
@expose
def process_existing_batches(self, notification_types: list[NotificationType]):
background_executor.submit(self._process_existing_batches, notification_types)
def _process_existing_batches(self, notification_types: list[NotificationType]):
"""Process existing batches for specified notification types"""
try:
processed_count = 0
@@ -312,66 +384,20 @@ class NotificationManager(AppService):
"timestamp": datetime.now(tz=timezone.utc).isoformat(),
}
@expose
def queue_notification(self, event: NotificationEventDTO) -> NotificationResult:
"""Queue a notification - exposed method for other services to call"""
try:
logger.info(f"Received Request to queue {event=}")
# Workaround for not being able to serialize generics over the expose bus
parsed_event = NotificationEventModel[
get_notif_data_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(parsed_event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
exchange = "notifications"
# Publish to RabbitMQ
self.run_and_wait(
self.rabbit.publish_message(
routing_key=routing_key,
message=message,
exchange=next(
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
),
)
)
return NotificationResult(
success=True,
message=f"Notification queued with routing key: {routing_key}",
)
except Exception as e:
logger.exception(f"Error queueing notification: {e}")
return NotificationResult(success=False, message=str(e))
def _queue_scheduled_notification(self, event: SummaryParamsEventDTO):
def _queue_scheduled_notification(self, event: SummaryParamsEventModel):
"""Queue a scheduled notification - exposed method for other services to call"""
try:
logger.info(f"Received Request to queue scheduled notification {event=}")
parsed_event = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate(event.model_dump())
routing_key = self.get_routing_key(event.type)
message = parsed_event.model_dump_json()
logger.info(f"Received Request to queue {message=}")
logger.debug(f"Received Request to queue scheduled notification {event=}")
exchange = "notifications"
routing_key = get_routing_key(event.type)
# Publish to RabbitMQ
self.run_and_wait(
self.rabbit.publish_message(
routing_key=routing_key,
message=message,
exchange=next(
ex for ex in self.rabbit_config.exchanges if ex.name == exchange
),
message=event.model_dump_json(),
exchange=next(ex for ex in EXCHANGES if ex.name == exchange),
)
)
@@ -497,13 +523,12 @@ class NotificationManager(AppService):
)
return False
def _parse_message(self, message: str) -> NotificationEvent | None:
def _parse_message(self, message: str) -> NotificationEventModel | None:
try:
event = NotificationEventDTO.model_validate_json(message)
model = NotificationEventModel[
event = BaseEventModel.model_validate_json(message)
return NotificationEventModel[
get_notif_data_type(event.type)
].model_validate_json(message)
return NotificationEvent(event=event, model=model)
except Exception as e:
logger.error(f"Error parsing message due to non matching schema {e}")
return None
@@ -511,14 +536,12 @@ class NotificationManager(AppService):
def _process_admin_message(self, message: str) -> bool:
"""Process a single notification, sending to an admin, returning whether to put into the failed queue"""
try:
parsed = self._parse_message(message)
if not parsed:
event = self._parse_message(message)
if not event:
return False
event = parsed.event
model = parsed.model
logger.debug(f"Processing notification for admin: {model}")
logger.debug(f"Processing notification for admin: {event}")
recipient_email = settings.config.refund_notification_email
self.email_sender.send_templated(event.type, recipient_email, model)
self.email_sender.send_templated(event.type, recipient_email, event)
return True
except Exception as e:
logger.exception(f"Error processing notification for admin queue: {e}")
@@ -527,12 +550,10 @@ class NotificationManager(AppService):
def _process_immediate(self, message: str) -> bool:
"""Process a single notification immediately, returning whether to put into the failed queue"""
try:
parsed = self._parse_message(message)
if not parsed:
event = self._parse_message(message)
if not event:
return False
event = parsed.event
model = parsed.model
logger.debug(f"Processing immediate notification: {model}")
logger.debug(f"Processing immediate notification: {event}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
@@ -553,7 +574,7 @@ class NotificationManager(AppService):
self.email_sender.send_templated(
notification=event.type,
user_email=recipient_email,
data=model,
data=event,
user_unsub_link=unsub_link,
)
return True
@@ -564,12 +585,10 @@ class NotificationManager(AppService):
def _process_batch(self, message: str) -> bool:
"""Process a single notification with a batching strategy, returning whether to put into the failed queue"""
try:
parsed = self._parse_message(message)
if not parsed:
event = self._parse_message(message)
if not event:
return False
event = parsed.event
model = parsed.model
logger.info(f"Processing batch notification: {model}")
logger.info(f"Processing batch notification: {event}")
recipient_email = get_db().get_user_email_by_id(event.user_id)
if not recipient_email:
@@ -585,7 +604,7 @@ class NotificationManager(AppService):
)
return True
should_send = self._should_batch(event.user_id, event.type, model)
should_send = self._should_batch(event.user_id, event.type, event)
if not should_send:
logger.info("Batch not old enough to send")
@@ -627,7 +646,7 @@ class NotificationManager(AppService):
"""Process a single notification with a summary strategy, returning whether to put into the failed queue"""
try:
logger.info(f"Processing summary notification: {message}")
event = SummaryParamsEventDTO.model_validate_json(message)
event = BaseEventModel.model_validate_json(message)
model = SummaryParamsEventModel[
get_summary_params_type(event.type)
].model_validate_json(message)
@@ -764,7 +783,5 @@ class NotificationManagerClient(AppServiceClient):
def get_service_type(cls):
return NotificationManager
queue_notification_async = endpoint_to_async(NotificationManager.queue_notification)
queue_notification = NotificationManager.queue_notification
process_existing_batches = NotificationManager.process_existing_batches
queue_weekly_summary = NotificationManager.queue_weekly_summary