mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
## Summary Agent generation (`create_agent`, `edit_agent`) can take 1-5 minutes. Previously, if the user closed their browser tab during this time: 1. The SSE connection would die 2. The tool execution would be cancelled via `CancelledError` 3. The result would be lost - even if the agent-generator service completed successfully This PR ensures long-running tool operations survive SSE disconnections. ### Changes 🏗️ **Backend:** - **base.py**: Added `is_long_running` property to `BaseTool` for tools to opt-in to background execution - **create_agent.py / edit_agent.py**: Set `is_long_running = True` - **models.py**: Added `OperationStartedResponse`, `OperationPendingResponse`, `OperationInProgressResponse` types - **service.py**: Modified `_yield_tool_call()` to: - Check if tool is `is_long_running` - Save "pending" message to chat history immediately - Spawn background task that runs independently of SSE - Return `operation_started` immediately (don't wait) - Update chat history with result when background task completes - Track running operations for idempotency (prevents duplicate ops on refresh) - **db.py**: Added `update_tool_message_content()` to update pending messages - **model.py**: Added `invalidate_session_cache()` to clear Redis after background completion **Frontend:** - **useChatMessage.ts**: Added operation message types - **helpers.ts**: Handle `operation_started`, `operation_pending`, `operation_in_progress` response types - **PendingOperationWidget**: New component to display operation status with spinner - **ChatMessage.tsx**: Render `PendingOperationWidget` for operation messages ### How It Works ``` User Request → Save "pending" message → Spawn background task → Return immediately ↓ Task runs independently of SSE ↓ On completion: Update message in chat history ↓ User refreshes → Loads history → Sees result ``` ### User Experience 1. User requests agent creation 2. Sees "Agent creation started. 
You can close this tab - check your library in a few minutes." 3. Can close browser tab safely 4. When they return, chat shows the completed result (or error) ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] pyright passes (0 errors) - [x] TypeScript checks pass - [x] Formatters applied ### Test Plan 1. Start agent creation in copilot 2. Close browser tab immediately after seeing "operation_started" 3. Wait 2-3 minutes 4. Reopen chat 5. Verify: Chat history shows completion message and agent appears in library --------- Co-authored-by: Ubbe <hi@ubbe.dev>
292 lines
9.7 KiB
Python
292 lines
9.7 KiB
Python
"""Database operations for chat sessions."""
|
|
|
|
import asyncio
|
|
import logging
|
|
from datetime import UTC, datetime
|
|
from typing import Any, cast
|
|
|
|
from prisma.models import ChatMessage as PrismaChatMessage
|
|
from prisma.models import ChatSession as PrismaChatSession
|
|
from prisma.types import (
|
|
ChatMessageCreateInput,
|
|
ChatSessionCreateInput,
|
|
ChatSessionUpdateInput,
|
|
ChatSessionWhereInput,
|
|
)
|
|
|
|
from backend.data.db import transaction
|
|
from backend.util.json import SafeJson
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def get_chat_session(session_id: str) -> PrismaChatSession | None:
    """Fetch a chat session (including its messages) by ID.

    Returns None when no session with the given ID exists. Messages, when
    present, are returned in ascending sequence order.
    """
    record = await PrismaChatSession.prisma().find_unique(
        where={"id": session_id},
        include={"Messages": True},
    )
    if record is None or not record.Messages:
        return record
    # Prisma's Python client (unlike Prisma JS) does not support order_by
    # inside an include clause, so messages are sorted here after fetching.
    record.Messages.sort(key=lambda msg: msg.sequence)
    return record
|
|
|
|
|
|
async def create_chat_session(
    session_id: str,
    user_id: str,
) -> PrismaChatSession:
    """Insert a new chat session row for *user_id* with empty metadata blobs."""
    return await PrismaChatSession.prisma().create(
        data=ChatSessionCreateInput(
            id=session_id,
            userId=user_id,
            # JSON columns start out as empty objects.
            credentials=SafeJson({}),
            successfulAgentRuns=SafeJson({}),
            successfulAgentSchedules=SafeJson({}),
        ),
        include={"Messages": True},
    )
|
|
|
|
|
|
async def update_chat_session(
    session_id: str,
    credentials: dict[str, Any] | None = None,
    successful_agent_runs: dict[str, Any] | None = None,
    successful_agent_schedules: dict[str, Any] | None = None,
    total_prompt_tokens: int | None = None,
    total_completion_tokens: int | None = None,
    title: str | None = None,
) -> PrismaChatSession | None:
    """Update a chat session's metadata.

    Only arguments that are not None are written; ``updatedAt`` is always
    refreshed. Returns the updated session with messages sorted by sequence.
    """
    updates: dict[str, Any] = {"updatedAt": datetime.now(UTC)}

    # JSON-typed columns are wrapped in SafeJson before persisting.
    if credentials is not None:
        updates["credentials"] = SafeJson(credentials)
    if successful_agent_runs is not None:
        updates["successfulAgentRuns"] = SafeJson(successful_agent_runs)
    if successful_agent_schedules is not None:
        updates["successfulAgentSchedules"] = SafeJson(successful_agent_schedules)

    # Scalar columns are stored as-is.
    if total_prompt_tokens is not None:
        updates["totalPromptTokens"] = total_prompt_tokens
    if total_completion_tokens is not None:
        updates["totalCompletionTokens"] = total_completion_tokens
    if title is not None:
        updates["title"] = title

    record = await PrismaChatSession.prisma().update(
        where={"id": session_id},
        data=cast(ChatSessionUpdateInput, updates),
        include={"Messages": True},
    )
    if record and record.Messages:
        # Prisma Python doesn't support order_by in include clauses; sort here.
        record.Messages.sort(key=lambda msg: msg.sequence)
    return record
|
|
|
|
|
|
async def add_chat_message(
    session_id: str,
    role: str,
    sequence: int,
    content: str | None = None,
    name: str | None = None,
    tool_call_id: str | None = None,
    refusal: str | None = None,
    tool_calls: list[dict[str, Any]] | None = None,
    function_call: dict[str, Any] | None = None,
) -> PrismaChatMessage:
    """Append a single message to a chat session and refresh its timestamp.

    The create input is assembled as a plain dict rather than using
    ChatMessageCreateInput directly, because Prisma's TypedDict validation
    rejects optional fields explicitly set to None. Only fields that have
    values are included, then the dict is cast for the client call.
    """
    data: dict[str, Any] = {
        "Session": {"connect": {"id": session_id}},
        "role": role,
        "sequence": sequence,
    }

    # Optional scalar columns: copy over only the ones that were provided.
    optional_scalars = {
        "content": content,
        "name": name,
        "toolCallId": tool_call_id,
        "refusal": refusal,
    }
    data.update({key: val for key, val in optional_scalars.items() if val is not None})

    # Optional JSON columns need SafeJson wrapping before persisting.
    if tool_calls is not None:
        data["toolCalls"] = SafeJson(tool_calls)
    if function_call is not None:
        data["functionCall"] = SafeJson(function_call)

    # Touch the session's updatedAt concurrently with the insert to reduce latency.
    _, message = await asyncio.gather(
        PrismaChatSession.prisma().update(
            where={"id": session_id},
            data={"updatedAt": datetime.now(UTC)},
        ),
        PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)),
    )
    return message
|
|
|
|
|
|
async def add_chat_messages_batch(
    session_id: str,
    messages: list[dict[str, Any]],
    start_sequence: int,
) -> list[PrismaChatMessage]:
    """Add multiple messages to a chat session in a batch.

    Messages are numbered consecutively starting at *start_sequence*. Uses a
    transaction for atomicity - if any message creation fails, the entire
    batch is rolled back.
    """
    if not messages:
        return []

    def build_input(msg: dict[str, Any], sequence: int) -> ChatMessageCreateInput:
        # Assemble a plain dict rather than using ChatMessageCreateInput
        # directly: Prisma's TypedDict validation rejects optional fields
        # explicitly set to None, so only populated fields are included,
        # then the dict is cast.
        record: dict[str, Any] = {
            "Session": {"connect": {"id": session_id}},
            "role": msg["role"],
            "sequence": sequence,
        }
        # Optional scalar fields (message key -> Prisma column name).
        for source_key, column in (
            ("content", "content"),
            ("name", "name"),
            ("tool_call_id", "toolCallId"),
            ("refusal", "refusal"),
        ):
            if msg.get(source_key) is not None:
                record[column] = msg[source_key]
        # Optional JSON fields need SafeJson wrapping.
        for source_key, column in (
            ("tool_calls", "toolCalls"),
            ("function_call", "functionCall"),
        ):
            if msg.get(source_key) is not None:
                record[column] = SafeJson(msg[source_key])
        return cast(ChatMessageCreateInput, record)

    created_messages: list[PrismaChatMessage] = []
    async with transaction() as tx:
        for offset, msg in enumerate(messages):
            created = await PrismaChatMessage.prisma(tx).create(
                data=build_input(msg, start_sequence + offset)
            )
            created_messages.append(created)

        # Refresh the session's updatedAt timestamp within the same transaction.
        # Note: Token usage (total_prompt_tokens, total_completion_tokens) is
        # updated separately via update_chat_session() after streaming completes.
        await PrismaChatSession.prisma(tx).update(
            where={"id": session_id},
            data={"updatedAt": datetime.now(UTC)},
        )

    return created_messages
|
|
|
|
|
|
async def get_user_chat_sessions(
    user_id: str,
    limit: int = 50,
    offset: int = 0,
) -> list[PrismaChatSession]:
    """List a user's chat sessions, most recently updated first.

    Supports pagination via *limit* (page size) and *offset* (rows skipped).
    """
    owned_by_user: ChatSessionWhereInput = {"userId": user_id}
    return await PrismaChatSession.prisma().find_many(
        where=owned_by_user,
        order={"updatedAt": "desc"},
        take=limit,
        skip=offset,
    )
|
|
|
|
|
|
async def get_user_session_count(user_id: str) -> int:
    """Return how many chat sessions belong to *user_id*."""
    owned_by_user: ChatSessionWhereInput = {"userId": user_id}
    return await PrismaChatSession.prisma().count(where=owned_by_user)
|
|
|
|
|
|
async def delete_chat_session(session_id: str, user_id: str | None = None) -> bool:
    """Delete a chat session and all its messages.

    Args:
        session_id: The session ID to delete.
        user_id: If provided, validates that the session belongs to this user
            before deletion. This prevents unauthorized deletion of other
            users' sessions.

    Returns:
        True if deleted successfully, False otherwise.
    """
    try:
        # Build typed where clause with optional user_id validation
        where_clause: ChatSessionWhereInput = {"id": session_id}
        if user_id is not None:
            where_clause["userId"] = user_id

        # delete_many returns the deleted-row count, so a non-matching
        # user_id yields 0 rather than raising.
        result = await PrismaChatSession.prisma().delete_many(where=where_clause)
        if result == 0:
            logger.warning(
                f"No session deleted for {session_id} "
                f"(user_id validation: {user_id is not None})"
            )
            return False
        return True
    except Exception:
        # logger.exception records the full traceback, which logger.error
        # with the exception interpolated into the message would discard.
        logger.exception(f"Failed to delete chat session {session_id}")
        return False
|
|
|
|
|
|
async def get_chat_session_message_count(session_id: str) -> int:
    """Return the number of messages stored for *session_id*."""
    return await PrismaChatMessage.prisma().count(
        where={"sessionId": session_id}
    )
|
|
|
|
|
|
async def update_tool_message_content(
    session_id: str,
    tool_call_id: str,
    new_content: str,
) -> bool:
    """Update the content of a tool message in chat history.

    Used by background tasks to update pending operation messages with final results.

    Args:
        session_id: The chat session ID.
        tool_call_id: The tool call ID to find the message.
        new_content: The new content to set.

    Returns:
        True if a message was updated, False otherwise.
    """
    try:
        # update_many returns the affected-row count; matching on both
        # sessionId and toolCallId scopes the update to this session.
        result = await PrismaChatMessage.prisma().update_many(
            where={
                "sessionId": session_id,
                "toolCallId": tool_call_id,
            },
            data={
                "content": new_content,
            },
        )
        if result == 0:
            logger.warning(
                f"No message found to update for session {session_id}, "
                f"tool_call_id {tool_call_id}"
            )
            return False
        return True
    except Exception:
        # logger.exception records the full traceback, which logger.error
        # with the exception interpolated into the message would discard.
        logger.exception(
            f"Failed to update tool message for session {session_id}, "
            f"tool_call_id {tool_call_id}"
        )
        return False
|