fix tests

This commit is contained in:
Swifty
2026-01-08 10:43:15 +01:00
parent fd768e8f3f
commit eee478fa31
2 changed files with 56 additions and 34 deletions

View File

@@ -2,7 +2,7 @@
import logging
from datetime import UTC, datetime
from typing import Any
from typing import Any, cast
from prisma.models import ChatMessage as PrismaChatMessage
from prisma.models import ChatSession as PrismaChatSession
@@ -94,17 +94,30 @@ async def add_chat_message(
function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
"""Add a message to a chat session."""
data = ChatMessageCreateInput(
Session={"connect": {"id": session_id}},
role=role,
sequence=sequence,
content=content,
name=name,
toolCallId=tool_call_id,
refusal=refusal,
toolCalls=SafeJson(tool_calls) if tool_calls is not None else None,
functionCall=SafeJson(function_call) if function_call is not None else None,
)
# Build the input dict dynamically - only include optional fields when they
# have values, as Prisma TypedDict validation fails when optional fields
# are explicitly set to None
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": role,
"sequence": sequence,
}
# Add optional string fields
if content is not None:
data["content"] = content
if name is not None:
data["name"] = name
if tool_call_id is not None:
data["toolCallId"] = tool_call_id
if refusal is not None:
data["refusal"] = refusal
# Add optional JSON fields only when they have values
if tool_calls is not None:
data["toolCalls"] = SafeJson(tool_calls)
if function_call is not None:
data["functionCall"] = SafeJson(function_call)
# Update session's updatedAt timestamp
await PrismaChatSession.prisma().update(
@@ -112,7 +125,9 @@ async def add_chat_message(
data={"updatedAt": datetime.now(UTC)},
)
return await PrismaChatMessage.prisma().create(data=data)
return await PrismaChatMessage.prisma().create(
data=cast(ChatMessageCreateInput, data)
)
async def add_chat_messages_batch(
@@ -126,27 +141,34 @@ async def add_chat_messages_batch(
created_messages = []
for i, msg in enumerate(messages):
data = ChatMessageCreateInput(
Session={"connect": {"id": session_id}},
role=msg["role"],
sequence=start_sequence + i,
content=msg.get("content"),
name=msg.get("name"),
toolCallId=msg.get("tool_call_id"),
refusal=msg.get("refusal"),
toolCalls=(
SafeJson(msg["tool_calls"])
if msg.get("tool_calls") is not None
else None
),
functionCall=(
SafeJson(msg["function_call"])
if msg.get("function_call") is not None
else None
),
)
# Build the input dict dynamically - only include optional JSON fields
# when they have values, as Prisma TypedDict validation fails when
# optional fields are explicitly set to None
data: dict[str, Any] = {
"Session": {"connect": {"id": session_id}},
"role": msg["role"],
"sequence": start_sequence + i,
}
created = await PrismaChatMessage.prisma().create(data=data)
# Add optional string fields
if msg.get("content") is not None:
data["content"] = msg["content"]
if msg.get("name") is not None:
data["name"] = msg["name"]
if msg.get("tool_call_id") is not None:
data["toolCallId"] = msg["tool_call_id"]
if msg.get("refusal") is not None:
data["refusal"] = msg["refusal"]
# Add optional JSON fields only when they have values
if msg.get("tool_calls") is not None:
data["toolCalls"] = SafeJson(msg["tool_calls"])
if msg.get("function_call") is not None:
data["functionCall"] = SafeJson(msg["function_call"])
created = await PrismaChatMessage.prisma().create(
data=cast(ChatMessageCreateInput, data)
)
created_messages.append(created)
# Update session's updatedAt timestamp

View File

@@ -78,7 +78,7 @@ async def test_chatsession_db_storage():
# Create session with messages including assistant message
s = ChatSession.new(user_id=None)
s.messages = messages # Contains user, assistant, and tool messages
assert s.session_id is not None, "Session id is not set"
# Upsert to save to both cache and DB
s = await upsert_chat_session(s)