mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fixing db queries
This commit is contained in:
@@ -2,9 +2,15 @@
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
from prisma.types import (
|
||||
ChatMessageCreateInput,
|
||||
ChatSessionCreateInput,
|
||||
ChatSessionUpdateInput,
|
||||
)
|
||||
|
||||
from backend.util import json
|
||||
|
||||
@@ -13,10 +19,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
|
||||
"""Get a chat session by ID from the database."""
|
||||
return await PrismaChatSession.prisma().find_unique(
|
||||
session = await PrismaChatSession.prisma().find_unique(
|
||||
where={"id": session_id},
|
||||
include={"Messages": {"order_by": {"sequence": "asc"}}},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
# Sort messages by sequence in Python since Prisma doesn't support order_by in include
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def create_chat_session(
|
||||
@@ -24,29 +34,30 @@ async def create_chat_session(
|
||||
user_id: str | None,
|
||||
) -> PrismaChatSession:
|
||||
"""Create a new chat session in the database."""
|
||||
data: ChatSessionCreateInput = {
|
||||
"id": session_id,
|
||||
"userId": user_id,
|
||||
"credentials": json.dumps({}),
|
||||
"successfulAgentRuns": json.dumps({}),
|
||||
"successfulAgentSchedules": json.dumps({}),
|
||||
}
|
||||
return await PrismaChatSession.prisma().create(
|
||||
data={
|
||||
"id": session_id,
|
||||
"userId": user_id,
|
||||
"credentials": json.dumps({}),
|
||||
"successfulAgentRuns": json.dumps({}),
|
||||
"successfulAgentSchedules": json.dumps({}),
|
||||
},
|
||||
data=data,
|
||||
include={"Messages": True},
|
||||
)
|
||||
|
||||
|
||||
async def update_chat_session(
|
||||
session_id: str,
|
||||
credentials: dict | None = None,
|
||||
successful_agent_runs: dict | None = None,
|
||||
successful_agent_schedules: dict | None = None,
|
||||
credentials: dict[str, Any] | None = None,
|
||||
successful_agent_runs: dict[str, Any] | None = None,
|
||||
successful_agent_schedules: dict[str, Any] | None = None,
|
||||
total_prompt_tokens: int | None = None,
|
||||
total_completion_tokens: int | None = None,
|
||||
title: str | None = None,
|
||||
) -> PrismaChatSession | None:
|
||||
"""Update a chat session's metadata."""
|
||||
data: dict = {"updatedAt": datetime.now(UTC)}
|
||||
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
|
||||
|
||||
if credentials is not None:
|
||||
data["credentials"] = json.dumps(credentials)
|
||||
@@ -61,11 +72,14 @@ async def update_chat_session(
|
||||
if title is not None:
|
||||
data["title"] = title
|
||||
|
||||
return await PrismaChatSession.prisma().update(
|
||||
session = await PrismaChatSession.prisma().update(
|
||||
where={"id": session_id},
|
||||
data=data,
|
||||
include={"Messages": {"order_by": {"sequence": "asc"}}},
|
||||
include={"Messages": True},
|
||||
)
|
||||
if session and session.Messages:
|
||||
session.Messages.sort(key=lambda m: m.sequence)
|
||||
return session
|
||||
|
||||
|
||||
async def add_chat_message(
|
||||
@@ -76,12 +90,12 @@ async def add_chat_message(
|
||||
name: str | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
refusal: str | None = None,
|
||||
tool_calls: list[dict] | None = None,
|
||||
function_call: dict | None = None,
|
||||
tool_calls: list[dict[str, Any]] | None = None,
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
data: dict = {
|
||||
"sessionId": session_id,
|
||||
data: ChatMessageCreateInput = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
@@ -110,7 +124,7 @@ async def add_chat_message(
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
session_id: str,
|
||||
messages: list[dict],
|
||||
messages: list[dict[str, Any]],
|
||||
start_sequence: int,
|
||||
) -> list[PrismaChatMessage]:
|
||||
"""Add multiple messages to a chat session in a batch."""
|
||||
@@ -119,8 +133,8 @@ async def add_chat_messages_batch(
|
||||
|
||||
created_messages = []
|
||||
for i, msg in enumerate(messages):
|
||||
data: dict = {
|
||||
"sessionId": session_id,
|
||||
data: ChatMessageCreateInput = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
@@ -158,7 +172,7 @@ async def get_user_chat_sessions(
|
||||
"""Get chat sessions for a user, ordered by most recent."""
|
||||
return await PrismaChatSession.prisma().find_many(
|
||||
where={"userId": user_id},
|
||||
order_by={"updatedAt": "desc"},
|
||||
order={"updatedAt": "desc"},
|
||||
take=limit,
|
||||
skip=offset,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user