mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
## Summary

Move DatabaseError from store-specific exceptions to generic backend exceptions for proper layer separation, while also fixing store exception inheritance to ensure proper HTTP status codes.

## Problem

1. **Poor Layer Separation**: DatabaseError was defined in store-specific exceptions but represents infrastructure concerns that affect the entire backend
2. **Incorrect HTTP Status Codes**: Store exceptions inherited from Exception instead of ValueError, causing 500 responses for client errors
3. **Reusability Issues**: Other backend modules couldn't use DatabaseError for DB operations
4. **Blanket Catch Issues**: Store-specific catches were affecting generic database operations

## Solution

### Move DatabaseError to Generic Location

- Move from backend.server.v2.store.exceptions to backend.util.exceptions
- Update all 23 references in backend/server/v2/store/db.py to use new location
- Remove from StoreError inheritance hierarchy

### Fix Complete Store Exception Hierarchy

- MediaUploadError: Changed from Exception to ValueError inheritance (client errors → 400)
- StoreError: Changed from Exception to ValueError inheritance (business logic errors → 400)
- Store NotFound exceptions: Changed to inherit from NotFoundError (→ 404)
- DatabaseError: Now properly inherits from Exception (infrastructure errors → 500)

## Benefits

### ✅ Proper Layer Separation

- Database errors are infrastructure concerns, not store-specific business logic
- Store exceptions focus on business validation and client errors
- Clean separation between infrastructure and business logic layers

### ✅ Correct HTTP Status Codes

- DatabaseError: 500 (server infrastructure errors)
- Store NotFound errors: 404 (via existing NotFoundError handler)
- Store validation errors: 400 (via existing ValueError handler)
- Media upload errors: 400 (client validation errors)

### ✅ Architectural Improvements

- DatabaseError now reusable across entire backend
- Eliminates blanket catch issues affecting generic DB operations
- All store exceptions use global exception handlers properly
- Future store exceptions automatically get proper status codes

## Files Changed

- **backend/util/exceptions.py**: Add DatabaseError class
- **backend/server/v2/store/exceptions.py**: Remove DatabaseError, fix inheritance hierarchy
- **backend/server/v2/store/db.py**: Update all DatabaseError references to new location

## Result

- ✅ **No more stack trace spam**: Expected business logic errors handled properly
- ✅ **Proper HTTP semantics**: 500 for infrastructure, 400/404 for client errors
- ✅ **Better architecture**: Clean layer separation and reusable components
- ✅ **Fixes original issue**: AgentNotFoundError now returns 404 instead of 500

This addresses the logging issue mentioned by @zamilmajdy while also implementing the architectural improvements suggested by @Pwuts.
654 lines
22 KiB
Python
654 lines
22 KiB
Python
import logging
|
|
from datetime import datetime, timedelta, timezone
|
|
from enum import Enum
|
|
from typing import Annotated, Any, Generic, Optional, TypeVar, Union
|
|
|
|
from prisma import Json
|
|
from prisma.enums import NotificationType
|
|
from prisma.models import NotificationEvent, UserNotificationBatch
|
|
from prisma.types import (
|
|
NotificationEventCreateInput,
|
|
UserNotificationBatchCreateInput,
|
|
UserNotificationBatchWhereInput,
|
|
)
|
|
|
|
# from backend.notifications.models import NotificationEvent
|
|
from pydantic import BaseModel, ConfigDict, EmailStr, Field, field_validator
|
|
|
|
from backend.util.exceptions import DatabaseError
|
|
from backend.util.json import SafeJson
|
|
from backend.util.logging import TruncatedLogger
|
|
|
|
from .db import transaction
|
|
|
|
# Module-level logger: a TruncatedLogger wrapping this module's standard
# logger, configured with the "[NotificationService]" prefix.
logger = TruncatedLogger(
    logging.getLogger(__name__),
    prefix="[NotificationService]",
)

# Covariant type variables used to parameterise the generic event models
# defined further down. The bounds are forward references (strings) because
# the bound classes are declared later in this module.
NotificationDataType_co = TypeVar(
    "NotificationDataType_co",
    bound="BaseNotificationData",
    covariant=True,
)
SummaryParamsType_co = TypeVar(
    "SummaryParamsType_co",
    bound="BaseSummaryParams",
    covariant=True,
)
|
|
|
|
|
|
class QueueType(Enum):
    """Queueing strategies that a notification type can be assigned to."""

    # Send right away (errors, critical notifications)
    IMMEDIATE = "immediate"
    # Batch for up to an hour (usage reports)
    BATCH = "batch"
    # Daily digest (summary notifications)
    SUMMARY = "summary"
    # Backoff strategy (exponential backoff)
    BACKOFF = "backoff"
    # Admin notifications (errors, critical notifications)
    ADMIN = "admin"
|
|
|
|
|
|
class BaseNotificationData(BaseModel):
    # Common base class for every notification payload model in this module.
    # extra="allow" tells pydantic to accept (rather than reject) fields that
    # are not declared on the concrete subclass.
    model_config = ConfigDict(
        extra="allow",
    )
|
|
|
|
|
|
class AgentRunData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.AGENT_RUN. All fields are required.
    agent_name: str
    credits_used: float
    execution_time: float
    node_count: int = Field(
        ...,
        description="Number of nodes executed",
    )
    graph_id: str
    outputs: list[dict[str, Any]] = Field(
        ...,
        description="Outputs of the agent",
    )
|
|
|
|
|
|
class ZeroBalanceData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.ZERO_BALANCE. All fields are required.
    agent_name: str = Field(
        ...,
        description="Name of the agent",
    )
    current_balance: float = Field(
        ...,
        description="Current balance in credits (100 = $1)",
    )
    billing_page_link: str = Field(
        ...,
        description="Link to billing page",
    )
    shortfall: float = Field(
        ...,
        description="Amount of credits needed to continue",
    )
|
|
|
|
|
|
class LowBalanceData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.LOW_BALANCE. Both fields are required.
    current_balance: float = Field(
        ...,
        description="Current balance in credits (100 = $1)",
    )
    billing_page_link: str = Field(
        ...,
        description="Link to billing page",
    )
|
|
|
|
|
|
class BlockExecutionFailedData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.BLOCK_EXECUTION_FAILED. All fields are required strings.

    # The block that failed
    block_name: str
    block_id: str
    error_message: str

    # Identifiers locating the failure
    graph_id: str
    node_id: str
    execution_id: str
|
|
|
|
|
|
class ContinuousAgentErrorData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.CONTINUOUS_AGENT_ERROR.
    agent_name: str
    error_message: str
    graph_id: str
    execution_id: str
    start_time: datetime
    error_time: datetime
    attempts: int = Field(
        ...,
        description="Number of retry attempts made",
    )

    @field_validator("start_time", "error_time")
    @classmethod
    def validate_timezone(cls, value: datetime):
        # Timezone-aware datetimes pass through unchanged; naive ones
        # (tzinfo is None) are rejected.
        if value.tzinfo is not None:
            return value
        raise ValueError("datetime must have timezone information")
|
|
|
|
|
|
class BaseSummaryData(BaseNotificationData):
    # Fields shared by the summary payload models that subclass this one
    # (DailySummaryData and WeeklySummaryData). All fields are required.

    # Totals
    total_credits_used: float
    total_executions: int
    most_used_agent: str
    total_execution_time: float

    # Run outcome counts
    successful_runs: int
    failed_runs: int

    # Derived figures
    average_execution_time: float
    cost_breakdown: dict[str, float]
|
|
|
|
|
|
class BaseSummaryParams(BaseModel):
    # Base parameter model for summary requests: a start/end datetime pair.
    # Both values must be timezone-aware.
    start_date: datetime
    end_date: datetime

    @field_validator("start_date", "end_date")
    @classmethod
    def validate_timezone(cls, value: datetime) -> datetime:
        # Reject naive datetimes so callers never mix naive and aware values.
        if value.tzinfo is None:
            raise ValueError("datetime must have timezone information")
        return value
|
|
|
|
|
|
class DailySummaryParams(BaseSummaryParams):
    # Parameter model that get_summary_params_type() returns for
    # NotificationType.DAILY_SUMMARY. Adds `date` to the inherited
    # start_date / end_date pair.
    date: datetime

    # This validator has the same attribute name as the one on
    # BaseSummaryParams, so it takes the place of the inherited validator on
    # this class. The inherited fields are therefore listed here explicitly so
    # that start_date and end_date keep their timezone check.
    @field_validator("date", "start_date", "end_date")
    @classmethod
    def validate_timezone(cls, value: datetime) -> datetime:
        # Reject naive datetimes; aware values pass through unchanged.
        if value.tzinfo is None:
            raise ValueError("datetime must have timezone information")
        return value
|
|
|
|
|
|
class WeeklySummaryParams(BaseSummaryParams):
    # Parameter model that get_summary_params_type() returns for
    # NotificationType.WEEKLY_SUMMARY. It re-declares the same two fields and
    # the same timezone check that BaseSummaryParams already defines.
    start_date: datetime
    end_date: datetime

    @field_validator("start_date", "end_date")
    @classmethod
    def validate_timezone(cls, value: datetime) -> datetime:
        # Reject naive datetimes; aware values pass through unchanged.
        if value.tzinfo is None:
            raise ValueError("datetime must have timezone information")
        return value
|
|
|
|
|
|
class DailySummaryData(BaseSummaryData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.DAILY_SUMMARY: the BaseSummaryData fields plus the
    # timezone-aware `date` the summary is for.
    date: datetime

    @field_validator("date")
    @classmethod
    def validate_timezone(cls, value: datetime) -> datetime:
        # Reject naive datetimes; aware values pass through unchanged.
        if value.tzinfo is None:
            raise ValueError("datetime must have timezone information")
        return value
|
|
|
|
|
|
class WeeklySummaryData(BaseSummaryData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.WEEKLY_SUMMARY: the BaseSummaryData fields plus a
    # timezone-aware start/end datetime pair.
    start_date: datetime
    end_date: datetime

    @field_validator("start_date", "end_date")
    @classmethod
    def validate_timezone(cls, value: datetime) -> datetime:
        # Reject naive datetimes; aware values pass through unchanged.
        if value.tzinfo is None:
            raise ValueError("datetime must have timezone information")
        return value
|
|
|
|
|
|
class MonthlySummaryData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.MONTHLY_SUMMARY. It derives from BaseNotificationData
    # directly (not BaseSummaryData) and only carries the month/year pair.
    month: int
    year: int
|
|
|
|
|
|
class RefundRequestData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for both
    # NotificationType.REFUND_REQUEST and NotificationType.REFUND_PROCESSED.

    # Who the refund concerns
    user_id: str
    user_name: str
    user_email: str

    # Which transaction / request
    transaction_id: str
    refund_request_id: str
    reason: str

    # Monetary figures. The subject templates in NotificationTypeOverride
    # render `amount / 100`.
    amount: float
    balance: int
|
|
|
|
|
|
class AgentApprovalData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.AGENT_APPROVED.

    # The agent under review
    agent_name: str
    agent_id: str
    agent_version: int

    # The review itself
    reviewer_name: str
    reviewer_email: str
    comments: str
    reviewed_at: datetime

    store_url: str

    @field_validator("reviewed_at")
    @classmethod
    def validate_timezone(cls, value: datetime):
        # Timezone-aware datetimes pass through unchanged; naive ones
        # (tzinfo is None) are rejected.
        if value.tzinfo is not None:
            return value
        raise ValueError("datetime must have timezone information")
|
|
|
|
|
|
class AgentRejectionData(BaseNotificationData):
    # Payload model that get_notif_data_type() returns for
    # NotificationType.AGENT_REJECTED.

    # The agent under review
    agent_name: str
    agent_id: str
    agent_version: int

    # The review itself
    reviewer_name: str
    reviewer_email: str
    comments: str
    reviewed_at: datetime

    resubmit_url: str

    @field_validator("reviewed_at")
    @classmethod
    def validate_timezone(cls, value: datetime):
        # Timezone-aware datetimes pass through unchanged; naive ones
        # (tzinfo is None) are rejected.
        if value.tzinfo is not None:
            return value
        raise ValueError("datetime must have timezone information")
|
|
|
|
|
|
# Union of the payload models a notification event may carry. It includes
# AgentApprovalData and AgentRejectionData so that it covers every model
# returned by get_notif_data_type().
# NOTE(review): the discriminator names a "type" field, but none of the member
# models visible in this file declare one - confirm how this alias is meant to
# be validated before relying on discriminated-union behaviour.
NotificationData = Annotated[
    Union[
        AgentRunData,
        ZeroBalanceData,
        LowBalanceData,
        BlockExecutionFailedData,
        ContinuousAgentErrorData,
        MonthlySummaryData,
        WeeklySummaryData,
        DailySummaryData,
        RefundRequestData,
        AgentApprovalData,
        AgentRejectionData,
        BaseSummaryData,
    ],
    Field(discriminator="type"),
]
|
|
|
|
|
|
class BaseEventModel(BaseModel):
    # Fields common to every event model: the notification type, the user it
    # belongs to, and a creation timestamp.
    type: NotificationType
    user_id: str
    # Defaults to the current time in UTC when not supplied.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
|
|
|
|
|
|
class NotificationEventModel(BaseEventModel, Generic[NotificationDataType_co]):
    # A notification event carrying a typed payload in `data`.
    id: Optional[str] = None  # None when creating, populated when reading from DB
    data: NotificationDataType_co

    @field_validator("type", mode="before")
    @classmethod
    def uppercase_type(cls, v):
        # Runs before pydantic's own validation of `type`: string input is
        # upper-cased, anything else is passed through untouched.
        if isinstance(v, str):
            return v.upper()
        return v

    @property
    def strategy(self) -> QueueType:
        # Queueing strategy configured for this event's notification type.
        return NotificationTypeOverride(self.type).strategy

    @property
    def template(self) -> str:
        # Template name configured for this event's notification type.
        return NotificationTypeOverride(self.type).template
|
|
|
|
|
|
class SummaryParamsEventModel(BaseEventModel, Generic[SummaryParamsType_co]):
    # Event model whose `data` is a summary-parameters object (a type bound
    # to BaseSummaryParams via SummaryParamsType_co).
    data: SummaryParamsType_co
|
|
|
|
|
|
def get_notif_data_type(
    notification_type: NotificationType,
) -> type[BaseNotificationData]:
    """Return the payload model class associated with ``notification_type``.

    Raises:
        KeyError: if ``notification_type`` has no entry in the mapping.
    """
    payload_models = {
        NotificationType.AGENT_RUN: AgentRunData,
        NotificationType.ZERO_BALANCE: ZeroBalanceData,
        NotificationType.LOW_BALANCE: LowBalanceData,
        NotificationType.BLOCK_EXECUTION_FAILED: BlockExecutionFailedData,
        NotificationType.CONTINUOUS_AGENT_ERROR: ContinuousAgentErrorData,
        NotificationType.DAILY_SUMMARY: DailySummaryData,
        NotificationType.WEEKLY_SUMMARY: WeeklySummaryData,
        NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
        # Both refund notification types share one payload model.
        NotificationType.REFUND_REQUEST: RefundRequestData,
        NotificationType.REFUND_PROCESSED: RefundRequestData,
        NotificationType.AGENT_APPROVED: AgentApprovalData,
        NotificationType.AGENT_REJECTED: AgentRejectionData,
    }
    return payload_models[notification_type]
|
|
|
|
|
|
def get_summary_params_type(
    notification_type: NotificationType,
) -> type[BaseSummaryParams]:
    """Return the summary-parameters model class for ``notification_type``.

    Only the daily and weekly summary types have an entry.

    Raises:
        KeyError: if ``notification_type`` is any other type.
    """
    params_models = {
        NotificationType.DAILY_SUMMARY: DailySummaryParams,
        NotificationType.WEEKLY_SUMMARY: WeeklySummaryParams,
    }
    return params_models[notification_type]
|
|
|
|
|
|
class NotificationBatch(BaseModel):
    # A group of NotificationEvent records for one user, together with the
    # queueing strategy it is handled under.
    user_id: str
    events: list[NotificationEvent]
    strategy: QueueType
    # Defaults to the current time in UTC when not supplied.
    last_update: datetime = Field(
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
|
|
|
|
|
|
class NotificationResult(BaseModel):
    # Outcome record: a success flag plus an optional message.
    success: bool
    message: Optional[str] = None  # omitted -> None
|
|
|
|
|
|
class NotificationTypeOverride:
    """Per-notification-type settings: queue strategy, template and subject."""

    def __init__(self, notification_type: NotificationType):
        self.notification_type = notification_type

    @property
    def strategy(self) -> QueueType:
        """Queue strategy for this type; unlisted types fall back to IMMEDIATE."""
        strategy_by_type = {
            # Batched by the notification service
            NotificationType.AGENT_RUN: QueueType.BATCH,
            # Sent immediately
            NotificationType.ZERO_BALANCE: QueueType.IMMEDIATE,
            NotificationType.LOW_BALANCE: QueueType.IMMEDIATE,
            NotificationType.AGENT_APPROVED: QueueType.IMMEDIATE,
            NotificationType.AGENT_REJECTED: QueueType.IMMEDIATE,
            # Handled with the backoff strategy
            NotificationType.BLOCK_EXECUTION_FAILED: QueueType.BACKOFF,
            NotificationType.CONTINUOUS_AGENT_ERROR: QueueType.BACKOFF,
            # Summary digests
            NotificationType.DAILY_SUMMARY: QueueType.SUMMARY,
            NotificationType.WEEKLY_SUMMARY: QueueType.SUMMARY,
            NotificationType.MONTHLY_SUMMARY: QueueType.SUMMARY,
            # Admin queue
            NotificationType.REFUND_REQUEST: QueueType.ADMIN,
            NotificationType.REFUND_PROCESSED: QueueType.ADMIN,
        }
        return strategy_by_type.get(self.notification_type, QueueType.IMMEDIATE)

    @property
    def template(self) -> str:
        """Returns template name for this notification type"""
        # Unlike `strategy`, there is no fallback: an unlisted type raises
        # KeyError.
        template_by_type = {
            NotificationType.AGENT_RUN: "agent_run.html",
            NotificationType.ZERO_BALANCE: "zero_balance.html",
            NotificationType.LOW_BALANCE: "low_balance.html",
            NotificationType.BLOCK_EXECUTION_FAILED: "block_failed.html",
            NotificationType.CONTINUOUS_AGENT_ERROR: "agent_error.html",
            NotificationType.DAILY_SUMMARY: "daily_summary.html",
            NotificationType.WEEKLY_SUMMARY: "weekly_summary.html",
            NotificationType.MONTHLY_SUMMARY: "monthly_summary.html",
            NotificationType.REFUND_REQUEST: "refund_request.html",
            NotificationType.REFUND_PROCESSED: "refund_processed.html",
            NotificationType.AGENT_APPROVED: "agent_approved.html",
            NotificationType.AGENT_REJECTED: "agent_rejected.html",
        }
        return template_by_type[self.notification_type]

    @property
    def subject(self) -> str:
        """Subject line for this notification type (KeyError if unlisted).

        Some subjects contain ``{{ ... }}`` placeholders; they are returned
        verbatim by this property.
        """
        subject_by_type = {
            NotificationType.AGENT_RUN: "Agent Run Report",
            NotificationType.ZERO_BALANCE: "You're out of credits!",
            NotificationType.LOW_BALANCE: "Low Balance Warning!",
            NotificationType.BLOCK_EXECUTION_FAILED: "Uh oh! Block Execution Failed",
            NotificationType.CONTINUOUS_AGENT_ERROR: "Shoot! Continuous Agent Error",
            NotificationType.DAILY_SUMMARY: "Here's your daily summary!",
            NotificationType.WEEKLY_SUMMARY: "Look at all the cool stuff you did last week!",
            NotificationType.MONTHLY_SUMMARY: "We did a lot this month!",
            NotificationType.REFUND_REQUEST: "[ACTION REQUIRED] You got a ${{data.amount / 100}} refund request from {{data.user_name}}",
            NotificationType.REFUND_PROCESSED: "Refund for ${{data.amount / 100}} to {{data.user_name}} has been processed",
            NotificationType.AGENT_APPROVED: "🎉 Your agent '{{data.agent_name}}' has been approved!",
            NotificationType.AGENT_REJECTED: "Your agent '{{data.agent_name}}' needs some updates",
        }
        return subject_by_type[self.notification_type]
|
|
|
|
|
|
class NotificationPreferenceDTO(BaseModel):
    # Transfer object for a user's notification preferences. All three fields
    # are required.
    email: EmailStr = Field(
        ...,
        description="User's email address",
    )
    preferences: dict[NotificationType, bool] = Field(
        ...,
        description="Which notifications the user wants",
    )
    daily_limit: int = Field(
        ...,
        description="Max emails per day",
    )
|
|
|
|
|
|
class NotificationPreference(BaseModel):
    # A user's notification preferences plus the counters kept alongside them.
    user_id: str
    email: EmailStr
    # Starts out as an empty mapping when not supplied.
    preferences: dict[NotificationType, bool] = Field(
        default_factory=dict,
        description="Which notifications the user wants",
    )
    daily_limit: int = 10  # Max emails per day
    emails_sent_today: int = 0
    # Defaults to the current time in UTC when not supplied.
    last_reset_date: datetime = Field(
        default_factory=lambda: datetime.now(timezone.utc),
    )
|
|
|
|
|
|
class UserNotificationEventDTO(BaseModel):
    # Transfer object mirroring one NotificationEvent database row.
    id: str  # Added to track notifications for removal
    type: NotificationType
    data: dict
    created_at: datetime
    updated_at: datetime

    @staticmethod
    def from_db(model: NotificationEvent) -> "UserNotificationEventDTO":
        # Copy the row's fields across, converting the stored JSON payload
        # into a plain dict and the camelCase timestamps to snake_case.
        payload = dict(model.data)
        return UserNotificationEventDTO(
            id=model.id,
            type=model.type,
            data=payload,
            created_at=model.createdAt,
            updated_at=model.updatedAt,
        )
|
|
|
|
|
|
class UserNotificationBatchDTO(BaseModel):
    # Transfer object mirroring one UserNotificationBatch row together with
    # the notification events attached to it.
    user_id: str
    type: NotificationType
    notifications: list[UserNotificationEventDTO]
    created_at: datetime
    updated_at: datetime

    @staticmethod
    def from_db(model: UserNotificationBatch) -> "UserNotificationBatchDTO":
        # The relation may be unset/empty on the row; treat that as "no
        # notifications" rather than failing.
        events = model.Notifications or []
        converted = [UserNotificationEventDTO.from_db(event) for event in events]
        return UserNotificationBatchDTO(
            user_id=model.userId,
            type=model.type,
            notifications=converted,
            created_at=model.createdAt,
            updated_at=model.updatedAt,
        )
|
|
|
|
|
|
def get_batch_delay(notification_type: NotificationType) -> timedelta:
    """Return the batch delay configured for ``notification_type``.

    Raises:
        KeyError: if the type has no delay configured below.
    """
    one_hour = timedelta(minutes=60)
    delays = {
        NotificationType.AGENT_RUN: timedelta(days=1),
        NotificationType.ZERO_BALANCE: one_hour,
        NotificationType.LOW_BALANCE: one_hour,
        NotificationType.BLOCK_EXECUTION_FAILED: one_hour,
        NotificationType.CONTINUOUS_AGENT_ERROR: one_hour,
    }
    return delays[notification_type]
|
|
|
|
|
|
async def create_or_add_to_user_notification_batch(
    user_id: str,
    notification_type: NotificationType,
    notification_data: NotificationEventModel,
) -> UserNotificationBatchDTO:
    """Append a notification event to the user's batch, creating the batch if needed.

    Looks up the batch for ``(user_id, notification_type)``. When one exists,
    a new event is created inside it; otherwise a new batch holding that
    single event is created.

    Returns:
        The resulting batch (with its notifications) as a DTO.

    Raises:
        DatabaseError: on any failure, including a missing payload (the
            ValueError raised for that is wrapped as well).
    """
    try:
        if not notification_data.data:
            raise ValueError("Notification data must be provided")

        # Serialize the payload once; both branches below store the same event.
        payload: Json = SafeJson(notification_data.data.model_dump())
        new_event = NotificationEventCreateInput(
            type=notification_type,
            data=payload,
        )

        # Look for the user's existing batch of this type.
        composite_key = {"userId": user_id, "type": notification_type}
        existing_batch = await UserNotificationBatch.prisma().find_unique(
            where={"userId_type": composite_key},
            include={"Notifications": True},
        )

        if existing_batch:
            updated = await UserNotificationBatch.prisma().update(
                where={"id": existing_batch.id},
                data={"Notifications": {"create": [new_event]}},
                include={"Notifications": True},
            )
            if not updated:
                raise DatabaseError(
                    f"Failed to add notification event to existing batch {existing_batch.id}"
                )
            return UserNotificationBatchDTO.from_db(updated)

        # No batch yet: create one that already contains the event.
        created = await UserNotificationBatch.prisma().create(
            data=UserNotificationBatchCreateInput(
                userId=user_id,
                type=notification_type,
                Notifications={"create": [new_event]},
            ),
            include={"Notifications": True},
        )
        return UserNotificationBatchDTO.from_db(created)
    except Exception as e:
        raise DatabaseError(
            f"Failed to create or add to notification batch for user {user_id} and type {notification_type}: {e}"
        ) from e
|
|
|
|
|
|
async def get_user_notification_oldest_message_in_batch(
    user_id: str,
    notification_type: NotificationType,
) -> UserNotificationEventDTO | None:
    """Return the oldest notification in the user's batch of the given type.

    Returns:
        The event with the smallest ``createdAt`` as a DTO, or ``None`` when
        the user has no batch of this type or the batch has no notifications.

    Raises:
        DatabaseError: if the lookup or conversion fails for any reason.
    """
    try:
        batch = await UserNotificationBatch.prisma().find_first(
            where={"userId": user_id, "type": notification_type},
            include={"Notifications": True},
        )
        if not batch or not batch.Notifications:
            return None

        # Only the earliest event is needed, so take the minimum directly
        # instead of sorting the whole list. Like sorted(...)[0], min() yields
        # the first element among ties.
        oldest = min(batch.Notifications, key=lambda event: event.createdAt)
        return UserNotificationEventDTO.from_db(oldest)
    except Exception as e:
        raise DatabaseError(
            f"Failed to get user notification last message in batch for user {user_id} and type {notification_type}: {e}"
        ) from e
|
|
|
|
|
|
async def empty_user_notification_batch(
    user_id: str, notification_type: NotificationType
) -> None:
    """Delete the user's batch of the given type together with its events.

    Both deletions run inside one ``transaction()`` block: first the events
    attached to the batch, then the batch row(s).

    Raises:
        DatabaseError: if anything inside the transaction fails.
    """
    try:
        # Events are selected through their owning batch.
        events_of_batch = {
            "UserNotificationBatch": {
                "is": {"userId": user_id, "type": notification_type}
            }
        }
        batch_filter = UserNotificationBatchWhereInput(
            userId=user_id,
            type=notification_type,
        )

        async with transaction() as tx:
            await tx.notificationevent.delete_many(where=events_of_batch)
            await tx.usernotificationbatch.delete_many(where=batch_filter)
    except Exception as e:
        raise DatabaseError(
            f"Failed to empty user notification batch for user {user_id} and type {notification_type}: {e}"
        ) from e
|
|
|
|
|
|
async def clear_all_user_notification_batches(user_id: str) -> None:
    """Clear ALL notification batches for a user across all types.

    Used when user's email is bounced/inactive and we should stop
    trying to send them ANY emails.

    The event and batch deletions run inside a single ``transaction()`` block
    (the same pattern as ``empty_user_notification_batch``), so a failure in
    the second step cannot leave batches behind with their events removed.

    Raises:
        DatabaseError: if either deletion (or the transaction itself) fails.
    """
    try:
        async with transaction() as tx:
            # Delete all notification events for this user
            await tx.notificationevent.delete_many(
                where={"UserNotificationBatch": {"is": {"userId": user_id}}}
            )

            # Delete all batches for this user
            await tx.usernotificationbatch.delete_many(where={"userId": user_id})

        # Logged only after the transaction block has exited without error.
        logger.info(f"Cleared all notification batches for user {user_id}")
    except Exception as e:
        raise DatabaseError(
            f"Failed to clear all notification batches for user {user_id}: {e}"
        ) from e
|
|
|
|
|
|
async def remove_notifications_from_batch(
    user_id: str, notification_type: NotificationType, notification_ids: list[str]
) -> None:
    """Delete the given notification events from the user's batch.

    This is used after successful sending to remove only the sent
    notifications, preventing duplicates on retry. When no events remain in
    the batch afterwards, the batch itself is deleted too. An empty
    ``notification_ids`` list is a no-op.

    Raises:
        DatabaseError: if any database call fails.
    """
    if not notification_ids:
        return

    try:
        # Delete only the listed events that belong to this user's batch.
        deleted_count = await NotificationEvent.prisma().delete_many(
            where={
                "id": {"in": notification_ids},
                "UserNotificationBatch": {
                    "is": {"userId": user_id, "type": notification_type}
                },
            }
        )
        logger.info(
            f"Removed {deleted_count} notifications from batch for user {user_id}"
        )

        # Count what is left in the batch.
        remaining = await NotificationEvent.prisma().count(
            where={
                "UserNotificationBatch": {
                    "is": {"userId": user_id, "type": notification_type}
                }
            }
        )
        if remaining != 0:
            return

        # Nothing left: drop the now-empty batch as well.
        await UserNotificationBatch.prisma().delete_many(
            where=UserNotificationBatchWhereInput(
                userId=user_id,
                type=notification_type,
            )
        )
        logger.info(
            f"Deleted empty batch for user {user_id} and type {notification_type}"
        )
    except Exception as e:
        raise DatabaseError(
            f"Failed to remove notifications from batch for user {user_id} and type {notification_type}: {e}"
        ) from e
|
|
|
|
|
|
async def get_user_notification_batch(
    user_id: str,
    notification_type: NotificationType,
) -> UserNotificationBatchDTO | None:
    """Fetch the user's batch of the given type, including its notifications.

    Returns:
        The batch as a DTO, or ``None`` when no matching batch exists.

    Raises:
        DatabaseError: if the lookup or conversion fails.
    """
    try:
        found = await UserNotificationBatch.prisma().find_first(
            where={"userId": user_id, "type": notification_type},
            include={"Notifications": True},
        )
        if not found:
            return None
        return UserNotificationBatchDTO.from_db(found)
    except Exception as e:
        raise DatabaseError(
            f"Failed to get user notification batch for user {user_id} and type {notification_type}: {e}"
        ) from e
|
|
|
|
|
|
async def get_all_batches_by_type(
    notification_type: NotificationType,
) -> list[UserNotificationBatchDTO]:
    """Return every non-empty batch of the given notification type.

    Raises:
        DatabaseError: if the query or conversion fails.
    """
    try:
        # "some": {} -> only return batches with at least one notification
        non_empty_of_type = {
            "type": notification_type,
            "Notifications": {"some": {}},
        }
        rows = await UserNotificationBatch.prisma().find_many(
            where=non_empty_of_type,
            include={"Notifications": True},
        )
        return list(map(UserNotificationBatchDTO.from_db, rows))
    except Exception as e:
        raise DatabaseError(
            f"Failed to get all batches by type {notification_type}: {e}"
        ) from e
|