mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
fix tests
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from prisma.models import ChatMessage as PrismaChatMessage
|
||||
from prisma.models import ChatSession as PrismaChatSession
|
||||
@@ -94,17 +94,30 @@ async def add_chat_message(
|
||||
function_call: dict[str, Any] | None = None,
|
||||
) -> PrismaChatMessage:
|
||||
"""Add a message to a chat session."""
|
||||
data = ChatMessageCreateInput(
|
||||
Session={"connect": {"id": session_id}},
|
||||
role=role,
|
||||
sequence=sequence,
|
||||
content=content,
|
||||
name=name,
|
||||
toolCallId=tool_call_id,
|
||||
refusal=refusal,
|
||||
toolCalls=SafeJson(tool_calls) if tool_calls is not None else None,
|
||||
functionCall=SafeJson(function_call) if function_call is not None else None,
|
||||
)
|
||||
# Build the input dict dynamically - only include optional fields when they
|
||||
# have values, as Prisma TypedDict validation fails when optional fields
|
||||
# are explicitly set to None
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": role,
|
||||
"sequence": sequence,
|
||||
}
|
||||
|
||||
# Add optional string fields
|
||||
if content is not None:
|
||||
data["content"] = content
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if tool_call_id is not None:
|
||||
data["toolCallId"] = tool_call_id
|
||||
if refusal is not None:
|
||||
data["refusal"] = refusal
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if tool_calls is not None:
|
||||
data["toolCalls"] = SafeJson(tool_calls)
|
||||
if function_call is not None:
|
||||
data["functionCall"] = SafeJson(function_call)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
await PrismaChatSession.prisma().update(
|
||||
@@ -112,7 +125,9 @@ async def add_chat_message(
|
||||
data={"updatedAt": datetime.now(UTC)},
|
||||
)
|
||||
|
||||
return await PrismaChatMessage.prisma().create(data=data)
|
||||
return await PrismaChatMessage.prisma().create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
|
||||
|
||||
async def add_chat_messages_batch(
|
||||
@@ -126,27 +141,34 @@ async def add_chat_messages_batch(
|
||||
|
||||
created_messages = []
|
||||
for i, msg in enumerate(messages):
|
||||
data = ChatMessageCreateInput(
|
||||
Session={"connect": {"id": session_id}},
|
||||
role=msg["role"],
|
||||
sequence=start_sequence + i,
|
||||
content=msg.get("content"),
|
||||
name=msg.get("name"),
|
||||
toolCallId=msg.get("tool_call_id"),
|
||||
refusal=msg.get("refusal"),
|
||||
toolCalls=(
|
||||
SafeJson(msg["tool_calls"])
|
||||
if msg.get("tool_calls") is not None
|
||||
else None
|
||||
),
|
||||
functionCall=(
|
||||
SafeJson(msg["function_call"])
|
||||
if msg.get("function_call") is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
# Build the input dict dynamically - only include optional JSON fields
|
||||
# when they have values, as Prisma TypedDict validation fails when
|
||||
# optional fields are explicitly set to None
|
||||
data: dict[str, Any] = {
|
||||
"Session": {"connect": {"id": session_id}},
|
||||
"role": msg["role"],
|
||||
"sequence": start_sequence + i,
|
||||
}
|
||||
|
||||
created = await PrismaChatMessage.prisma().create(data=data)
|
||||
# Add optional string fields
|
||||
if msg.get("content") is not None:
|
||||
data["content"] = msg["content"]
|
||||
if msg.get("name") is not None:
|
||||
data["name"] = msg["name"]
|
||||
if msg.get("tool_call_id") is not None:
|
||||
data["toolCallId"] = msg["tool_call_id"]
|
||||
if msg.get("refusal") is not None:
|
||||
data["refusal"] = msg["refusal"]
|
||||
|
||||
# Add optional JSON fields only when they have values
|
||||
if msg.get("tool_calls") is not None:
|
||||
data["toolCalls"] = SafeJson(msg["tool_calls"])
|
||||
if msg.get("function_call") is not None:
|
||||
data["functionCall"] = SafeJson(msg["function_call"])
|
||||
|
||||
created = await PrismaChatMessage.prisma().create(
|
||||
data=cast(ChatMessageCreateInput, data)
|
||||
)
|
||||
created_messages.append(created)
|
||||
|
||||
# Update session's updatedAt timestamp
|
||||
|
||||
@@ -78,7 +78,7 @@ async def test_chatsession_db_storage():
|
||||
# Create session with messages including assistant message
|
||||
s = ChatSession.new(user_id=None)
|
||||
s.messages = messages # Contains user, assistant, and tool messages
|
||||
|
||||
assert s.session_id is not None, "Session id is not set"
|
||||
# Upsert to save to both cache and DB
|
||||
s = await upsert_chat_session(s)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user