mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
## Summary

- Adds a session-level `dry_run` flag that forces ALL tool calls (`run_block`, `run_agent`) in a copilot/autopilot session to use dry-run simulation mode
- Stores the flag in a typed `ChatSessionMetadata` JSON model on the `ChatSession` DB row, accessed via `session.dry_run` property
- Adds `dry_run` to the AutoPilot block Input schema so graph builders can create dry-run autopilot nodes
- Refactors multiple copilot tools from `**kwargs` to explicit parameters for type safety

## Changes

- **Prisma schema**: Added `metadata` JSON column to `ChatSession` model with migration
- **Python models**: Added `ChatSessionMetadata` model with `dry_run` field, added `metadata` field to `ChatSessionInfo` and `ChatSession`, updated `from_db()`, `new()`, and `create_chat_session()`
- **Session propagation**: `set_execution_context(user_id, session)` called from `baseline/service.py` so tool handlers can read session-level flags via `session.dry_run`
- **Tool enforcement**: `run_block` and `run_agent` check `session.dry_run` and force `dry_run=True` when set; `run_agent` blocks scheduling in dry-run sessions
- **AutoPilot block**: Added `dry_run` input field, passes it when creating sessions
- **Chat API**: Added `CreateSessionRequest` model with `dry_run` field to `POST /sessions` endpoint; added `metadata` to session responses
- **Frontend**: Updated `useChatSession.ts` to pass body to the create session mutation
- **Tool refactoring**: Multiple copilot tools refactored from `**kwargs` to explicit named parameters (agent_browser, manage_folders, workspace_files, connect_integration, agent_output, bash_exec, etc.) for better type safety

## Test plan

- [x] Unit tests for `ChatSession.new()` with dry_run parameter
- [x] Unit tests for `RunBlockTool` session dry_run override
- [x] Unit tests for `RunAgentTool` session dry_run override
- [x] Unit tests for session dry_run blocks scheduling
- [x] Existing dry_run tests still pass (12/12)
- [x] Existing permissions tests still pass
- [x] All pre-commit hooks pass (ruff, isort, pyright, tsc)
- [ ] Manual: Create autopilot session with `dry_run=True`, verify run_block/run_agent calls use simulation

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
397 lines
14 KiB
Python
397 lines
14 KiB
Python
"""Database operations for chat sessions."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
from prisma.errors import UniqueViolationError
|
|
from prisma.models import ChatMessage as PrismaChatMessage
|
|
from prisma.models import ChatSession as PrismaChatSession
|
|
from prisma.types import (
|
|
ChatMessageCreateInput,
|
|
ChatSessionCreateInput,
|
|
ChatSessionUpdateInput,
|
|
ChatSessionWhereInput,
|
|
)
|
|
|
|
from backend.data import db
|
|
from backend.util.json import SafeJson, sanitize_string
|
|
|
|
from .model import (
|
|
ChatMessage,
|
|
ChatSession,
|
|
ChatSessionInfo,
|
|
ChatSessionMetadata,
|
|
invalidate_session_cache,
|
|
)
|
|
|
|
# Module-level logger, named after this module, used by the DB helpers below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def get_chat_session(session_id: str) -> ChatSession | None:
    """Fetch a single chat session, together with its messages, by ID.

    The session's messages are loaded in the same query, ordered by
    ascending ``sequence``.

    Args:
        session_id: ID of the session row to look up.

    Returns:
        The row converted via ``ChatSession.from_db``, or ``None`` when no
        session matches ``session_id``.
    """
    row = await PrismaChatSession.prisma().find_unique(
        where={"id": session_id},
        include={"Messages": {"order_by": {"sequence": "asc"}}},
    )
    if not row:
        return None
    return ChatSession.from_db(row)
|
|
|
|
|
|
async def create_chat_session(
    session_id: str,
    user_id: str,
    metadata: ChatSessionMetadata | None = None,
) -> ChatSessionInfo:
    """Insert a new chat session row and return its summary.

    The ``credentials``, ``successfulAgentRuns`` and
    ``successfulAgentSchedules`` columns are initialised to empty JSON
    objects.

    Args:
        session_id: ID to assign to the new session.
        user_id: Stored as the session's ``userId``.
        metadata: Session metadata to persist. When not supplied, a
            default-constructed ``ChatSessionMetadata`` is stored instead.

    Returns:
        The created row converted via ``ChatSessionInfo.from_db``.
    """
    effective_metadata = metadata or ChatSessionMetadata()
    create_input = ChatSessionCreateInput(
        id=session_id,
        userId=user_id,
        credentials=SafeJson({}),
        successfulAgentRuns=SafeJson({}),
        successfulAgentSchedules=SafeJson({}),
        metadata=SafeJson(effective_metadata.model_dump()),
    )
    created = await PrismaChatSession.prisma().create(data=create_input)
    return ChatSessionInfo.from_db(created)
|
|
|
|
|
|
async def update_chat_session(
    session_id: str,
    credentials: dict[str, Any] | None = None,
    successful_agent_runs: dict[str, Any] | None = None,
    successful_agent_schedules: dict[str, Any] | None = None,
    total_prompt_tokens: int | None = None,
    total_completion_tokens: int | None = None,
    title: str | None = None,
) -> ChatSession | None:
    """Apply a partial update to a chat session.

    Only arguments that are not ``None`` are written; ``updatedAt`` is always
    refreshed to the current UTC time.

    Note: ``metadata`` (which includes ``dry_run``) is deliberately not
    updatable here — it is set once at creation time and treated as immutable
    for the lifetime of the session.

    Args:
        session_id: ID of the session to update.
        credentials: New value for the ``credentials`` JSON column.
        successful_agent_runs: New value for ``successfulAgentRuns``.
        successful_agent_schedules: New value for ``successfulAgentSchedules``.
        total_prompt_tokens: New value for ``totalPromptTokens``.
        total_completion_tokens: New value for ``totalCompletionTokens``.
        title: New session title.

    Returns:
        The updated session (messages included, ordered by ascending
        ``sequence``) converted via ``ChatSession.from_db``, or ``None`` when
        the update yields no row.
    """
    changes: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)}

    # JSON columns are wrapped in SafeJson before being handed to Prisma.
    if credentials is not None:
        changes["credentials"] = SafeJson(credentials)
    if successful_agent_runs is not None:
        changes["successfulAgentRuns"] = SafeJson(successful_agent_runs)
    if successful_agent_schedules is not None:
        changes["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)

    # Scalar columns are passed through unchanged.
    if total_prompt_tokens is not None:
        changes["totalPromptTokens"] = total_prompt_tokens
    if total_completion_tokens is not None:
        changes["totalCompletionTokens"] = total_completion_tokens
    if title is not None:
        changes["title"] = title

    updated = await PrismaChatSession.prisma().update(
        where={"id": session_id},
        data=changes,
        include={"Messages": {"order_by": {"sequence": "asc"}}},
    )
    if not updated:
        return None
    return ChatSession.from_db(updated)
|
|
|
|
|
|
async def update_chat_session_title(
    session_id: str,
    user_id: str,
    title: str,
    *,
    only_if_empty: bool = False,
) -> bool:
    """Set a chat session's title, restricted to the owning user.

    The update is always filtered on both ``session_id`` and ``user_id``, so
    a caller who merely knows a session ID cannot rename another user's
    session. ``updatedAt`` is refreshed alongside the title.

    Args:
        session_id: ID of the session to rename.
        user_id: Owner the session must belong to.
        title: Title to store.
        only_if_empty: When True, the filter additionally requires the
            current title to be NULL, so the check and the write happen in
            one ``update_many`` call and an auto-generated title never
            overwrites one that is already set.

    Returns:
        True if at least one row was updated; False otherwise (session not
        found, wrong user, or — with ``only_if_empty`` — a title already
        existed).
    """
    filters: ChatSessionWhereInput = {"id": session_id, "userId": user_id}
    if only_if_empty:
        filters["title"] = None

    updated_rows = await PrismaChatSession.prisma().update_many(
        where=filters,
        data={"title": title, "updatedAt": datetime.now(UTC)},
    )
    return updated_rows > 0
|
|
|
|
|
|
async def add_chat_message(
    session_id: str,
    role: str,
    sequence: int,
    content: str | None = None,
    name: str | None = None,
    tool_call_id: str | None = None,
    refusal: str | None = None,
    tool_calls: list[dict[str, Any]] | None = None,
    function_call: dict[str, Any] | None = None,
) -> ChatMessage:
    """Append one message to a chat session and touch the session timestamp.

    Args:
        session_id: Session the message is connected to.
        role: Message role.
        sequence: Sequence number of the message within the session.
        content: Message text; passed through ``sanitize_string`` before
            storage.
        name: Optional name stored on the message.
        tool_call_id: Optional tool call ID, stored as ``toolCallId``.
        refusal: Optional refusal text; passed through ``sanitize_string``.
        tool_calls: Optional tool-call payload, stored as JSON.
        function_call: Optional function-call payload, stored as JSON.

    Returns:
        The created message converted via ``ChatMessage.from_db``.
    """
    # Only keys with real values are placed in the payload
    # (Prisma TypedDict rejects optional fields set to None).
    payload: ChatMessageCreateInput = {
        "Session": {"connect": {"id": session_id}},
        "role": role,
        "sequence": sequence,
    }

    # Free-text fields are sanitized to strip PostgreSQL-incompatible control
    # characters (null bytes etc.) that may appear in tool outputs.
    if content is not None:
        payload["content"] = sanitize_string(content)
    if name is not None:
        payload["name"] = name
    if tool_call_id is not None:
        payload["toolCallId"] = tool_call_id
    if refusal is not None:
        payload["refusal"] = sanitize_string(refusal)

    # JSON fields are wrapped in SafeJson.
    if tool_calls is not None:
        payload["toolCalls"] = SafeJson(tool_calls)
    if function_call is not None:
        payload["functionCall"] = SafeJson(function_call)

    # The session timestamp bump and the message insert are issued
    # concurrently for lower latency.
    touch_session = PrismaChatSession.prisma().update(
        where={"id": session_id},
        data={"updatedAt": datetime.now(UTC)},
    )
    insert_message = PrismaChatMessage.prisma().create(data=payload)
    outcomes = await asyncio.gather(touch_session, insert_message)
    return ChatMessage.from_db(outcomes[1])
|
|
|
|
|
|
async def add_chat_messages_batch(
    session_id: str,
    messages: list[dict[str, Any]],
    start_sequence: int,
) -> int:
    """Insert several messages into a chat session in one transaction.

    The insert is attempted optimistically at ``start_sequence``. When a
    unique constraint violation is raised (e.g. the streaming loop and a
    long-running callback racing for the same sequence numbers), the latest
    sequence is re-read via ``get_next_sequence`` and the insert is retried
    from there, up to five attempts in total. The common, collision-free
    path therefore needs no extra query.

    Args:
        session_id: Session the messages belong to.
        messages: Message dicts. ``role`` is required; ``content``, ``name``,
            ``tool_call_id``, ``refusal``, ``tool_calls``, ``function_call``
            and ``duration_ms`` are read when present and not ``None``.
        start_sequence: Sequence number to give the first message.

    Returns:
        The sequence number the next message should use: the start sequence
        actually used (possibly adjusted after a collision) plus
        ``len(messages)``. For an empty ``messages`` list this is
        ``start_sequence`` unchanged.

    Raises:
        UniqueViolationError: If the final attempt still collides.
    """
    if not messages:
        # Nothing to insert; the caller's counter is already correct.
        return start_sequence

    max_retries = 5
    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            # One timestamp shared by every message row and the session update.
            stamp = datetime.now(UTC)

            async with db.transaction() as tx:
                rows = []
                for offset, msg in enumerate(messages):
                    # Only non-None values are included (Prisma TypedDict
                    # rejects optional fields set to None). create_many does
                    # not support nested creates, so sessionId is set directly.
                    row: ChatMessageCreateInput = {
                        "sessionId": session_id,
                        "role": msg["role"],
                        "sequence": start_sequence + offset,
                        "createdAt": stamp,
                    }

                    # Free-text fields are sanitized to strip
                    # PostgreSQL-incompatible control characters.
                    content = msg.get("content")
                    if content is not None:
                        row["content"] = sanitize_string(content)
                    name = msg.get("name")
                    if name is not None:
                        row["name"] = name
                    tool_call_id = msg.get("tool_call_id")
                    if tool_call_id is not None:
                        row["toolCallId"] = tool_call_id
                    refusal = msg.get("refusal")
                    if refusal is not None:
                        row["refusal"] = sanitize_string(refusal)

                    # JSON fields are wrapped in SafeJson.
                    tool_calls = msg.get("tool_calls")
                    if tool_calls is not None:
                        row["toolCalls"] = SafeJson(tool_calls)
                    function_call = msg.get("function_call")
                    if function_call is not None:
                        row["functionCall"] = SafeJson(function_call)

                    duration_ms = msg.get("duration_ms")
                    if duration_ms is not None:
                        row["durationMs"] = duration_ms

                    rows.append(row)

                # The bulk insert and the session timestamp bump are issued
                # together inside the transaction, both using ``stamp``.
                insert_rows = PrismaChatMessage.prisma(tx).create_many(data=rows)
                touch_session = PrismaChatSession.prisma(tx).update(
                    where={"id": session_id},
                    data={"updatedAt": stamp},
                )
                await asyncio.gather(insert_rows, touch_session)

            # Next free sequence number, so callers can resync their counters.
            return start_sequence + len(messages)

        except UniqueViolationError:
            if attempt >= final_attempt:
                # Retries exhausted — let the caller see the violation.
                raise

            # Collision: re-read the next free sequence and try again there.
            logger.info(
                f"Collision detected for session {session_id} at sequence "
                f"{start_sequence}, querying DB for latest sequence"
            )
            start_sequence = await get_next_sequence(session_id)
            logger.info(
                f"Retrying batch insert with start_sequence={start_sequence}"
            )

    # Defensive: the loop always returns or re-raises before getting here.
    raise RuntimeError(f"Failed to insert messages after {max_retries} attempts")
|
|
|
|
|
|
async def get_user_chat_sessions(
    user_id: str,
    limit: int = 50,
    offset: int = 0,
) -> list[ChatSessionInfo]:
    """List a user's chat sessions, most recently updated first.

    Args:
        user_id: Owner whose sessions are listed.
        limit: Maximum number of sessions to return.
        offset: Number of sessions to skip before the first returned one.

    Returns:
        The matching rows, each converted via ``ChatSessionInfo.from_db``.
    """
    rows = await PrismaChatSession.prisma().find_many(
        where={"userId": user_id},
        order={"updatedAt": "desc"},
        take=limit,
        skip=offset,
    )
    summaries = []
    for row in rows:
        summaries.append(ChatSessionInfo.from_db(row))
    return summaries
|
|
|
|
|
|
async def get_user_session_count(user_id: str) -> int:
    """Count the chat sessions owned by ``user_id``.

    Returns:
        The number of session rows whose ``userId`` equals ``user_id``.
    """
    owned_by_user: ChatSessionWhereInput = {"userId": user_id}
    total = await PrismaChatSession.prisma().count(where=owned_by_user)
    return total
|
|
|
|
|
|
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
    """Delete a chat session and all its messages.

    This is best-effort: any exception raised during the delete is logged
    and reported as ``False`` rather than propagated.

    Args:
        session_id: ID of the session to delete.
        user_id: When given, the delete is additionally filtered on this
            owner, which prevents deleting another user's session.

    Returns:
        True if at least one row was deleted; False if nothing matched or
        the delete failed.
    """
    try:
        # Filter on the session ID, plus the owner when one was supplied.
        filters: ChatSessionWhereInput = {"id": session_id}
        if user_id is not None:
            filters["userId"] = user_id

        deleted = await PrismaChatSession.prisma().delete_many(where=filters)
        if deleted != 0:
            return True

        logger.warning(
            f"No session deleted for {session_id} "
            f"(user_id validation: {user_id is not None})"
        )
        return False
    except Exception as e:
        logger.error(f"Failed to delete chat session {session_id}: {e}")
        return False
|
|
|
|
|
|
async def get_next_sequence(session_id: str) -> int:
    """Return the sequence number the next message in this session should use.

    Reads the highest existing ``sequence`` for the session (raw SQL selecting
    only that column, ``ORDER BY "sequence" DESC LIMIT 1``) and adds one.
    Basing the result on the highest sequence rather than a row count keeps
    it correct even when earlier messages have been deleted. The query is
    intended to be served by the unique index on (sessionId, sequence).

    Returns:
        ``0`` when the session has no messages, otherwise the highest
        existing sequence plus one.
    """
    rows = await db.query_raw_with_schema(
        'SELECT "sequence" FROM {schema_prefix}"ChatMessage" WHERE "sessionId" = $1 ORDER BY "sequence" DESC LIMIT 1',
        session_id,
    )
    if not rows:
        return 0
    return rows[0]["sequence"] + 1
|
|
|
|
|
|
async def update_tool_message_content(
    session_id: str,
    tool_call_id: str,
    new_content: str,
) -> bool:
    """Overwrite the content of a tool message in a session's history.

    Used by background tasks to replace a pending-operation message with the
    final result. This is best-effort: any exception raised during the update
    is logged and reported as ``False`` rather than propagated.

    Args:
        session_id: The chat session ID.
        tool_call_id: Tool call ID identifying the message(s) to update.
        new_content: Replacement content; passed through ``sanitize_string``
            before storage.

    Returns:
        True if at least one message was updated; False if none matched or
        the update failed.
    """
    try:
        updated = await PrismaChatMessage.prisma().update_many(
            where={
                "sessionId": session_id,
                "toolCallId": tool_call_id,
            },
            data={
                "content": sanitize_string(new_content),
            },
        )
        if updated != 0:
            return True

        logger.warning(
            f"No message found to update for session {session_id}, "
            f"tool_call_id {tool_call_id}"
        )
        return False
    except Exception as e:
        logger.error(
            f"Failed to update tool message for session {session_id}, "
            f"tool_call_id {tool_call_id}: {e}"
        )
        return False
|
|
|
|
|
|
async def set_turn_duration(session_id: str, duration_ms: int) -> None:
    """Record ``duration_ms`` on the session's most recent assistant message.

    The target is the message with role ``"assistant"`` and the highest
    ``sequence`` in the session. After the write, the cached session is
    invalidated so the next read returns the updated duration. Nothing is
    done when the session has no assistant message.

    Args:
        session_id: Session whose last assistant message is updated.
        duration_ms: Value stored in the message's ``durationMs`` field.
    """
    latest_assistant = await PrismaChatMessage.prisma().find_first(
        where={"sessionId": session_id, "role": "assistant"},
        order={"sequence": "desc"},
    )
    if not latest_assistant:
        return

    await PrismaChatMessage.prisma().update(
        where={"id": latest_assistant.id},
        data={"durationMs": duration_ms},
    )
    # Drop the cached copy so the session is re-fetched from DB with durationMs.
    await invalidate_session_cache(session_id)
|