feat(backend): add some more useful db queries

This commit is contained in:
Nicholas Tindle
2025-02-10 15:17:03 -06:00
parent 3f813c09a9
commit 9aba2f8c63
2 changed files with 97 additions and 24 deletions

View File

@@ -1,9 +1,8 @@
import logging
from datetime import datetime, timedelta
from enum import Enum
from typing import Annotated, Generic, Optional, TypeVar, Union
from typing import Annotated, Generic, Optional, Type, TypeVar, Union
import prisma
from prisma import Json
from prisma.enums import NotificationType
from prisma.models import NotificationEvent, UserNotificationBatch
@@ -12,6 +11,8 @@ from prisma.types import UserNotificationBatchWhereInput
# from backend.notifications.models import NotificationEvent
from pydantic import BaseModel, EmailStr, Field, field_validator
from .db import transaction
logger = logging.getLogger(__name__)
@@ -110,6 +111,13 @@ NotificationData = Annotated[
]
class NotificationEventDTO(BaseModel):
user_id: str
type: NotificationType
data: dict
created_at: datetime = Field(default_factory=datetime.now)
class NotificationEventModel(BaseModel, Generic[T_co]):
user_id: str
type: NotificationType
@@ -131,6 +139,21 @@ class NotificationEventModel(BaseModel, Generic[T_co]):
return NotificationTypeOverride(self.type).template
def get_data_type(
notification_type: NotificationType,
) -> Type[BaseNotificationData]:
return {
NotificationType.AGENT_RUN: AgentRunData,
NotificationType.ZERO_BALANCE: ZeroBalanceData,
NotificationType.LOW_BALANCE: LowBalanceData,
NotificationType.BLOCK_EXECUTION_FAILED: BlockExecutionFailedData,
NotificationType.CONTINUOUS_AGENT_ERROR: ContinuousAgentErrorData,
NotificationType.DAILY_SUMMARY: DailySummaryData,
NotificationType.WEEKLY_SUMMARY: WeeklySummaryData,
NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
}[notification_type]
class NotificationBatch(BaseModel):
user_id: str
events: list[NotificationEvent]
@@ -151,7 +174,7 @@ class NotificationTypeOverride:
def strategy(self) -> BatchingStrategy:
BATCHING_RULES = {
# These are batched by the notification service
NotificationType.AGENT_RUN: BatchingStrategy.HOURLY,
NotificationType.AGENT_RUN: BatchingStrategy.IMMEDIATE,
# These are batched by the notification service, but with a backoff strategy
NotificationType.ZERO_BALANCE: BatchingStrategy.BACKOFF,
NotificationType.LOW_BALANCE: BatchingStrategy.BACKOFF,
@@ -190,7 +213,7 @@ class NotificationPreference(BaseModel):
def get_batch_delay(notification_type: NotificationType) -> timedelta:
return {
NotificationType.AGENT_RUN: timedelta(minutes=1),
NotificationType.AGENT_RUN: timedelta(seconds=1),
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
NotificationType.LOW_BALANCE: timedelta(minutes=60),
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
@@ -201,30 +224,68 @@ def get_batch_delay(notification_type: NotificationType) -> timedelta:
async def create_or_add_to_user_notification_batch(
user_id: str,
notification_type: NotificationType,
data: NotificationEventModel[T_co],
) -> UserNotificationBatch:
data: str, # type: 'NotificationEventModel'
) -> dict:
logger.info(
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {data}"
)
notification_data = Json(data.model_dump(exclude={"user_id", "type", "created_at"}))
notification_data = NotificationEventModel[
get_data_type(notification_type)
].model_validate_json(data)
upsert = await UserNotificationBatch.prisma().upsert(
where=UserNotificationBatchWhereInput(userId=user_id, type=notification_type),
data={
"create": {
# Serialize the data
# serialized_data = json.dumps(notification_data.data.model_dump())
json_data: Json = Json(notification_data.data.model_dump_json())
# First try to find existing batch
existing_batch = await UserNotificationBatch.prisma().find_unique(
where={
"userId_type": {
"userId": user_id,
"type": notification_type,
"notifications": {
"create": [{"type": notification_type, "data": notification_data}]
},
},
"update": {
"notifications": {
"create": [{"type": notification_type, "data": notification_data}]
},
},
}
},
include={"notifications": True},
)
return upsert
if not existing_batch:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data={
"type": notification_type,
"data": json_data,
}
)
# Create new batch
resp = await tx.usernotificationbatch.create(
data={
"userId": user_id,
"type": notification_type,
"notifications": {"connect": [{"id": notification_event.id}]},
},
include={"notifications": True},
)
return resp.model_dump()
else:
async with transaction() as tx:
notification_event = await tx.notificationevent.create(
data={
"type": notification_type,
"data": json_data,
"UserNotificationBatch": {"connect": {"id": existing_batch.id}},
}
)
# Add to existing batch
resp = await tx.usernotificationbatch.update(
where={"id": existing_batch.id},
data={"notifications": {"connect": [{"id": notification_event.id}]}},
include={"notifications": True},
)
if not resp:
raise Exception("Failed to add to existing batch")
return resp.model_dump()
async def get_user_notification_last_message_in_batch(
@@ -245,8 +306,8 @@ async def get_user_notification_last_message_in_batch(
async def empty_user_notification_batch(
user_id: str, notification_type: NotificationType
) -> None:
async with prisma.Prisma().tx() as transaction:
await transaction.notificationevent.delete_many(
async with transaction() as tx:
await tx.notificationevent.delete_many(
where={
"UserNotificationBatch": {
"is": {"userId": user_id, "type": notification_type}
@@ -254,9 +315,19 @@ async def empty_user_notification_batch(
}
)
await transaction.usernotificationbatch.delete_many(
await tx.usernotificationbatch.delete_many(
where=UserNotificationBatchWhereInput(
userId=user_id,
type=notification_type,
)
)
async def get_user_notification_batch(
user_id: str,
notification_type: NotificationType,
) -> UserNotificationBatch | None:
return await UserNotificationBatch.prisma().find_first(
where={"userId": user_id, "type": notification_type},
include={"notifications": True},
)

View File

@@ -21,6 +21,7 @@ from backend.data.graph import get_graph, get_node
from backend.data.notifications import (
create_or_add_to_user_notification_batch,
empty_user_notification_batch,
get_user_notification_batch,
get_user_notification_last_message_in_batch,
)
from backend.data.user import (
@@ -117,3 +118,4 @@ class DatabaseManager(AppService):
get_user_notification_last_message_in_batch
)
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)