fixing db queries

This commit is contained in:
Swifty
2025-12-16 16:31:22 +01:00
parent a8c68b585a
commit 06ce6fa9a1

View File

@@ -2,9 +2,15 @@
import logging
from datetime import UTC, datetime
from typing import Any
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
from prisma.types import (
ChatMessageCreateInput,
ChatSessionCreateInput,
ChatSessionUpdateInput,
)
from backend.util import json
@@ -13,10 +19,14 @@ logger = logging.getLogger(__name__)
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
"""Get a chat session by ID from the database."""
return await PrismaChatSession.prisma().find_unique(
session = await PrismaChatSession.prisma().find_unique(
where={"id": session_id},
include={"Messages": {"order_by": {"sequence": "asc"}}},
include={"Messages": True},
)
if session and session.Messages:
# Sort messages by sequence in Python since Prisma doesn't support order_by in include
session.Messages.sort(key=lambda m: m.sequence)
return session
async def create_chat_session(
@@ -24,29 +34,30 @@ async def create_chat_session(
user_id: str | None,
) -> PrismaChatSession:
"""Create a new chat session in the database."""
data: ChatSessionCreateInput = {
"id": session_id,
"userId": user_id,
"credentials": json.dumps({}),
"successfulAgentRuns": json.dumps({}),
"successfulAgentSchedules": json.dumps({}),
}
return await PrismaChatSession.prisma().create(
data={
"id": session_id,
"userId": user_id,
"credentials": json.dumps({}),
"successfulAgentRuns": json.dumps({}),
"successfulAgentSchedules": json.dumps({}),
},
data=data,
include={"Messages": True},
)
async def update_chat_session(
session_id: str,
credentials: dict | None = None,
successful_agent_runs: dict | None = None,
successful_agent_schedules: dict | None = None,
credentials: dict[str, Any] | None = None,
successful_agent_runs: dict[str, Any] | None = None,
successful_agent_schedules: dict[str, Any] | None = None,
total_prompt_tokens: int | None = None,
total_completion_tokens: int | None = None,
title: str | None = None,
) -> PrismaChatSession | None:
"""Update a chat session's metadata."""
data: dict = {"updatedAt": datetime.now(UTC)}
data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}
if credentials is not None:
data["credentials"] = json.dumps(credentials)
@@ -61,11 +72,14 @@ async def update_chat_session(
if title is not None:
data["title"] = title
return await PrismaChatSession.prisma().update(
session = await PrismaChatSession.prisma().update(
where={"id": session_id},
data=data,
include={"Messages": {"order_by": {"sequence": "asc"}}},
include={"Messages": True},
)
if session and session.Messages:
session.Messages.sort(key=lambda m: m.sequence)
return session
async def add_chat_message(
@@ -76,12 +90,12 @@ async def add_chat_message(
name: str | None = None,
tool_call_id: str | None = None,
refusal: str | None = None,
tool_calls: list[dict] | None = None,
function_call: dict | None = None,
tool_calls: list[dict[str, Any]] | None = None,
function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
"""Add a message to a chat session."""
data: dict = {
"sessionId": session_id,
data: ChatMessageCreateInput = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
}
@@ -110,7 +124,7 @@ async def add_chat_message(
async def add_chat_messages_batch(
session_id: str,
messages: list[dict],
messages: list[dict[str, Any]],
start_sequence: int,
) -> list[PrismaChatMessage]:
"""Add multiple messages to a chat session in a batch."""
@@ -119,8 +133,8 @@ async def add_chat_messages_batch(
created_messages = []
for i, msg in enumerate(messages):
data: dict = {
"sessionId": session_id,
data: ChatMessageCreateInput = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
@@ -158,7 +172,7 @@ async def get_user_chat_sessions(
"""Get chat sessions for a user, ordered by most recent."""
return await PrismaChatSession.prisma().find_many(
where={"userId": user_id},
order_by={"updatedAt": "desc"},
order={"updatedAt": "desc"},
take=limit,
skip=offset,
)