mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
feat(backend): add some more useful db queries
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from enum import Enum
|
||||
from typing import Annotated, Generic, Optional, TypeVar, Union
|
||||
from typing import Annotated, Generic, Optional, Type, TypeVar, Union
|
||||
|
||||
import prisma
|
||||
from prisma import Json
|
||||
from prisma.enums import NotificationType
|
||||
from prisma.models import NotificationEvent, UserNotificationBatch
|
||||
@@ -12,6 +11,8 @@ from prisma.types import UserNotificationBatchWhereInput
|
||||
# from backend.notifications.models import NotificationEvent
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||
|
||||
from .db import transaction
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -110,6 +111,13 @@ NotificationData = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class NotificationEventDTO(BaseModel):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
data: dict
|
||||
created_at: datetime = Field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class NotificationEventModel(BaseModel, Generic[T_co]):
|
||||
user_id: str
|
||||
type: NotificationType
|
||||
@@ -131,6 +139,21 @@ class NotificationEventModel(BaseModel, Generic[T_co]):
|
||||
return NotificationTypeOverride(self.type).template
|
||||
|
||||
|
||||
def get_data_type(
|
||||
notification_type: NotificationType,
|
||||
) -> Type[BaseNotificationData]:
|
||||
return {
|
||||
NotificationType.AGENT_RUN: AgentRunData,
|
||||
NotificationType.ZERO_BALANCE: ZeroBalanceData,
|
||||
NotificationType.LOW_BALANCE: LowBalanceData,
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: BlockExecutionFailedData,
|
||||
NotificationType.CONTINUOUS_AGENT_ERROR: ContinuousAgentErrorData,
|
||||
NotificationType.DAILY_SUMMARY: DailySummaryData,
|
||||
NotificationType.WEEKLY_SUMMARY: WeeklySummaryData,
|
||||
NotificationType.MONTHLY_SUMMARY: MonthlySummaryData,
|
||||
}[notification_type]
|
||||
|
||||
|
||||
class NotificationBatch(BaseModel):
|
||||
user_id: str
|
||||
events: list[NotificationEvent]
|
||||
@@ -151,7 +174,7 @@ class NotificationTypeOverride:
|
||||
def strategy(self) -> BatchingStrategy:
|
||||
BATCHING_RULES = {
|
||||
# These are batched by the notification service
|
||||
NotificationType.AGENT_RUN: BatchingStrategy.HOURLY,
|
||||
NotificationType.AGENT_RUN: BatchingStrategy.IMMEDIATE,
|
||||
# These are batched by the notification service, but with a backoff strategy
|
||||
NotificationType.ZERO_BALANCE: BatchingStrategy.BACKOFF,
|
||||
NotificationType.LOW_BALANCE: BatchingStrategy.BACKOFF,
|
||||
@@ -190,7 +213,7 @@ class NotificationPreference(BaseModel):
|
||||
|
||||
def get_batch_delay(notification_type: NotificationType) -> timedelta:
|
||||
return {
|
||||
NotificationType.AGENT_RUN: timedelta(minutes=1),
|
||||
NotificationType.AGENT_RUN: timedelta(seconds=1),
|
||||
NotificationType.ZERO_BALANCE: timedelta(minutes=60),
|
||||
NotificationType.LOW_BALANCE: timedelta(minutes=60),
|
||||
NotificationType.BLOCK_EXECUTION_FAILED: timedelta(minutes=60),
|
||||
@@ -201,30 +224,68 @@ def get_batch_delay(notification_type: NotificationType) -> timedelta:
|
||||
async def create_or_add_to_user_notification_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType,
|
||||
data: NotificationEventModel[T_co],
|
||||
) -> UserNotificationBatch:
|
||||
data: str, # type: 'NotificationEventModel'
|
||||
) -> dict:
|
||||
logger.info(
|
||||
f"Creating or adding to notification batch for {user_id} with type {notification_type} and data {data}"
|
||||
)
|
||||
|
||||
notification_data = Json(data.model_dump(exclude={"user_id", "type", "created_at"}))
|
||||
notification_data = NotificationEventModel[
|
||||
get_data_type(notification_type)
|
||||
].model_validate_json(data)
|
||||
|
||||
upsert = await UserNotificationBatch.prisma().upsert(
|
||||
where=UserNotificationBatchWhereInput(userId=user_id, type=notification_type),
|
||||
data={
|
||||
"create": {
|
||||
# Serialize the data
|
||||
# serialized_data = json.dumps(notification_data.data.model_dump())
|
||||
json_data: Json = Json(notification_data.data.model_dump_json())
|
||||
|
||||
# First try to find existing batch
|
||||
existing_batch = await UserNotificationBatch.prisma().find_unique(
|
||||
where={
|
||||
"userId_type": {
|
||||
"userId": user_id,
|
||||
"type": notification_type,
|
||||
"notifications": {
|
||||
"create": [{"type": notification_type, "data": notification_data}]
|
||||
},
|
||||
},
|
||||
"update": {
|
||||
"notifications": {
|
||||
"create": [{"type": notification_type, "data": notification_data}]
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
include={"notifications": True},
|
||||
)
|
||||
return upsert
|
||||
|
||||
if not existing_batch:
|
||||
async with transaction() as tx:
|
||||
notification_event = await tx.notificationevent.create(
|
||||
data={
|
||||
"type": notification_type,
|
||||
"data": json_data,
|
||||
}
|
||||
)
|
||||
|
||||
# Create new batch
|
||||
resp = await tx.usernotificationbatch.create(
|
||||
data={
|
||||
"userId": user_id,
|
||||
"type": notification_type,
|
||||
"notifications": {"connect": [{"id": notification_event.id}]},
|
||||
},
|
||||
include={"notifications": True},
|
||||
)
|
||||
return resp.model_dump()
|
||||
else:
|
||||
async with transaction() as tx:
|
||||
notification_event = await tx.notificationevent.create(
|
||||
data={
|
||||
"type": notification_type,
|
||||
"data": json_data,
|
||||
"UserNotificationBatch": {"connect": {"id": existing_batch.id}},
|
||||
}
|
||||
)
|
||||
# Add to existing batch
|
||||
resp = await tx.usernotificationbatch.update(
|
||||
where={"id": existing_batch.id},
|
||||
data={"notifications": {"connect": [{"id": notification_event.id}]}},
|
||||
include={"notifications": True},
|
||||
)
|
||||
if not resp:
|
||||
raise Exception("Failed to add to existing batch")
|
||||
return resp.model_dump()
|
||||
|
||||
|
||||
async def get_user_notification_last_message_in_batch(
|
||||
@@ -245,8 +306,8 @@ async def get_user_notification_last_message_in_batch(
|
||||
async def empty_user_notification_batch(
|
||||
user_id: str, notification_type: NotificationType
|
||||
) -> None:
|
||||
async with prisma.Prisma().tx() as transaction:
|
||||
await transaction.notificationevent.delete_many(
|
||||
async with transaction() as tx:
|
||||
await tx.notificationevent.delete_many(
|
||||
where={
|
||||
"UserNotificationBatch": {
|
||||
"is": {"userId": user_id, "type": notification_type}
|
||||
@@ -254,9 +315,19 @@ async def empty_user_notification_batch(
|
||||
}
|
||||
)
|
||||
|
||||
await transaction.usernotificationbatch.delete_many(
|
||||
await tx.usernotificationbatch.delete_many(
|
||||
where=UserNotificationBatchWhereInput(
|
||||
userId=user_id,
|
||||
type=notification_type,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def get_user_notification_batch(
|
||||
user_id: str,
|
||||
notification_type: NotificationType,
|
||||
) -> UserNotificationBatch | None:
|
||||
return await UserNotificationBatch.prisma().find_first(
|
||||
where={"userId": user_id, "type": notification_type},
|
||||
include={"notifications": True},
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from backend.data.graph import get_graph, get_node
|
||||
from backend.data.notifications import (
|
||||
create_or_add_to_user_notification_batch,
|
||||
empty_user_notification_batch,
|
||||
get_user_notification_batch,
|
||||
get_user_notification_last_message_in_batch,
|
||||
)
|
||||
from backend.data.user import (
|
||||
@@ -117,3 +118,4 @@ class DatabaseManager(AppService):
|
||||
get_user_notification_last_message_in_batch
|
||||
)
|
||||
empty_user_notification_batch = exposed_run_and_wait(empty_user_notification_batch)
|
||||
get_user_notification_batch = exposed_run_and_wait(get_user_notification_batch)
|
||||
|
||||
Reference in New Issue
Block a user