From bfd04dcf04285f5aad9b9da22d515d3a3b1dc758 Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Fri, 13 Feb 2026 21:17:32 +0100 Subject: [PATCH] address comments --- .../backend/copilot/completion_consumer.py | 30 ++------- .../backend/copilot/completion_handler.py | 65 +++++++------------ .../backend/backend/copilot/db.py | 39 +++++------ .../backend/copilot/executor/manager.py | 2 +- .../backend/backend/copilot/model.py | 63 ++++++++---------- 5 files changed, 77 insertions(+), 122 deletions(-) diff --git a/autogpt_platform/backend/backend/copilot/completion_consumer.py b/autogpt_platform/backend/backend/copilot/completion_consumer.py index 622760f5d8..6627241f35 100644 --- a/autogpt_platform/backend/backend/copilot/completion_consumer.py +++ b/autogpt_platform/backend/backend/copilot/completion_consumer.py @@ -37,12 +37,10 @@ stale pending messages from dead consumers. import asyncio import logging -import os import uuid from typing import Any import orjson -from prisma import Prisma from pydantic import BaseModel from redis.exceptions import ResponseError @@ -69,8 +67,8 @@ class OperationCompleteMessage(BaseModel): class ChatCompletionConsumer: """Consumer for chat operation completion messages from Redis Streams. - This consumer initializes its own Prisma client in start() to ensure - database operations work correctly within this async context. + Database operations are handled through the chat_db() accessor, which + routes through DatabaseManager RPC when Prisma is not directly connected. Uses Redis consumer groups to allow multiple platform pods to consume messages reliably with automatic redelivery on failure. @@ -79,7 +77,6 @@ class ChatCompletionConsumer: def __init__(self): self._consumer_task: asyncio.Task | None = None self._running = False - self._prisma: Prisma | None = None self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" async def start(self) -> None: @@ -115,16 +112,6 @@ class ChatCompletionConsumer: f"Chat completion consumer started (consumer: {self._consumer_name})" ) - async def _ensure_prisma(self) -> Prisma: - """Lazily initialize Prisma client on first use.""" - if self._prisma is None: - database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432") - prisma = Prisma(datasource={"url": database_url}) - await prisma.connect() - self._prisma = prisma - logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)") - return self._prisma - async def stop(self) -> None: """Stop the completion consumer.""" self._running = False @@ -137,11 +124,6 @@ class ChatCompletionConsumer: pass self._consumer_task = None - if self._prisma: - await self._prisma.disconnect() - self._prisma = None - logger.info("[COMPLETION] Consumer Prisma client disconnected") - logger.info("Chat completion consumer stopped") async def _consume_messages(self) -> None: @@ -253,7 +235,7 @@ class ChatCompletionConsumer: # XAUTOCLAIM after min_idle_time expires async def _handle_message(self, body: bytes) -> None: - """Handle a completion message using our own Prisma client.""" + """Handle a completion message.""" try: data = orjson.loads(body) message = OperationCompleteMessage(**data) @@ -303,8 +285,7 @@ class ChatCompletionConsumer: message: OperationCompleteMessage, ) -> None: """Handle successful operation completion.""" - prisma = await self._ensure_prisma() - await process_operation_success(task, message.result, prisma) + await process_operation_success(task, message.result) async def _handle_failure( self, @@ -312,8 +293,7 @@ class ChatCompletionConsumer: message: OperationCompleteMessage, ) -> None: """Handle failed operation completion.""" - prisma = await self._ensure_prisma() - await process_operation_failure(task, message.error, prisma) + await process_operation_failure(task, message.error) # Module-level consumer instance diff --git a/autogpt_platform/backend/backend/copilot/completion_handler.py b/autogpt_platform/backend/backend/copilot/completion_handler.py index 905fa2ddba..fd971c5cc2 100644 --- a/autogpt_platform/backend/backend/copilot/completion_handler.py +++ b/autogpt_platform/backend/backend/copilot/completion_handler.py @@ -9,7 +9,8 @@ import logging from typing import Any import orjson -from prisma import Prisma + +from backend.data.db_accessors import chat_db from . import service as chat_service from . import stream_registry @@ -72,48 +73,40 @@ async def _update_tool_message( session_id: str, tool_call_id: str, content: str, - prisma_client: Prisma | None, ) -> None: - """Update tool message in database. + """Update tool message in database using the chat_db accessor. + + Routes through DatabaseManager RPC when Prisma is not directly + connected (e.g. in the CoPilot Executor microservice). Args: session_id: The session ID tool_call_id: The tool call ID to update content: The new content for the message - prisma_client: Optional Prisma client. If None, uses chat_service. Raises: - ToolMessageUpdateError: If the database update fails. The caller should - handle this to avoid marking the task as completed with inconsistent state. + ToolMessageUpdateError: If the database update fails. """ try: - if prisma_client: - # Use provided Prisma client (for consumer with its own connection) - updated_count = await prisma_client.chatmessage.update_many( - where={ - "sessionId": session_id, - "toolCallId": tool_call_id, - }, - data={"content": content}, - ) - # Check if any rows were updated - 0 means message not found - if updated_count == 0: - raise ToolMessageUpdateError( - f"No message found with tool_call_id={tool_call_id} in session {session_id}" - ) - else: - # Use service function (for webhook endpoint) - await chat_service._update_pending_operation( - session_id=session_id, - tool_call_id=tool_call_id, - result=content, + updated = await chat_db().update_tool_message_content( + session_id=session_id, + tool_call_id=tool_call_id, + new_content=content, + ) + if not updated: + raise ToolMessageUpdateError( + f"No message found with tool_call_id=" + f"{tool_call_id} in session {session_id}" ) except ToolMessageUpdateError: raise except Exception as e: - logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True) + logger.error( + f"[COMPLETION] Failed to update tool message: {e}", + exc_info=True, + ) raise ToolMessageUpdateError( - f"Failed to update tool message for tool_call_id={tool_call_id}: {e}" + f"Failed to update tool message for tool call #{tool_call_id}: {e}" ) from e @@ -202,7 +195,6 @@ async def _save_agent_from_result( async def process_operation_success( task: stream_registry.ActiveTask, result: dict | str | None, - prisma_client: Prisma | None = None, ) -> None: """Handle successful operation completion. @@ -212,12 +204,10 @@ async def process_operation_success( Args: task: The active task that completed result: The result data from the operation - prisma_client: Optional Prisma client for database operations. - If None, uses chat_service._update_pending_operation instead. Raises: - ToolMessageUpdateError: If the database update fails. The task will be - marked as failed instead of completed to avoid inconsistent state. + ToolMessageUpdateError: If the database update fails. The task + will be marked as failed instead of completed. """ # For agent generation tools, save the agent to library if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict): @@ -250,7 +240,6 @@ async def process_operation_success( session_id=task.session_id, tool_call_id=task.tool_call_id, content=result_str, - prisma_client=prisma_client, ) except ToolMessageUpdateError: # DB update failed - mark task as failed to avoid inconsistent state @@ -293,18 +282,15 @@ async def process_operation_success( async def process_operation_failure( task: stream_registry.ActiveTask, error: str | None, - prisma_client: Prisma | None = None, ) -> None: """Handle failed operation completion. - Publishes the error to the stream registry, updates the database with - the error response, and marks the task as failed. + Publishes the error to the stream registry, updates the database + with the error response, and marks the task as failed. Args: task: The active task that failed error: The error message from the operation - prisma_client: Optional Prisma client for database operations. - If None, uses chat_service._update_pending_operation instead. """ error_msg = error or "Operation failed" @@ -325,7 +311,6 @@ async def process_operation_failure( session_id=task.session_id, tool_call_id=task.tool_call_id, content=error_response.model_dump_json(), - prisma_client=prisma_client, ) except ToolMessageUpdateError: # DB update failed - log but continue with cleanup diff --git a/autogpt_platform/backend/backend/copilot/db.py b/autogpt_platform/backend/backend/copilot/db.py index f94d959f05..bce2ed9627 100644 --- a/autogpt_platform/backend/backend/copilot/db.py +++ b/autogpt_platform/backend/backend/copilot/db.py @@ -17,26 +17,24 @@ from prisma.types import ( from backend.data import db from backend.util.json import SafeJson +from .model import ChatMessage, ChatSession + logger = logging.getLogger(__name__) -async def get_chat_session(session_id: str) -> PrismaChatSession | None: +async def get_chat_session(session_id: str) -> ChatSession | None: """Get a chat session by ID from the database.""" session = await PrismaChatSession.prisma().find_unique( where={"id": session_id}, - include={"Messages": True}, + include={"Messages": {"order_by": {"sequence": "asc"}}}, ) - if session and session.Messages: - # Sort messages by sequence in Python - Prisma Python client doesn't support - # order_by in include clauses (unlike Prisma JS), so we sort after fetching - session.Messages.sort(key=lambda m: m.sequence) - return session + return ChatSession.from_db(session) if session else None async def create_chat_session( session_id: str, user_id: str, -) -> PrismaChatSession: +) -> ChatSession: """Create a new chat session in the database.""" data = ChatSessionCreateInput( id=session_id, @@ -45,7 +43,8 @@ async def create_chat_session( successfulAgentRuns=SafeJson({}), successfulAgentSchedules=SafeJson({}), ) - return await PrismaChatSession.prisma().create(data=data) + prisma_session = await PrismaChatSession.prisma().create(data=data) + return ChatSession.from_db(prisma_session) async def update_chat_session( @@ -56,7 +55,7 @@ async def update_chat_session( total_prompt_tokens: int | None = None, total_completion_tokens: int | None = None, title: str | None = None, -) -> PrismaChatSession | None: +) -> ChatSession | None: """Update a chat session's metadata.""" data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)} @@ -76,12 +75,9 @@ async def update_chat_session( session = await PrismaChatSession.prisma().update( where={"id": session_id}, data=data, - include={"Messages": True}, + include={"Messages": {"order_by": {"sequence": "asc"}}}, ) - if session and session.Messages: - # Sort in Python - Prisma Python doesn't support order_by in include clauses - session.Messages.sort(key=lambda m: m.sequence) - return session + return ChatSession.from_db(session) if session else None async def add_chat_message( @@ -94,7 +90,7 @@ async def add_chat_message( refusal: str | None = None, tool_calls: list[dict[str, Any]] | None = None, function_call: dict[str, Any] | None = None, -) -> PrismaChatMessage: +) -> ChatMessage: """Add a message to a chat session.""" # Build input dict dynamically rather than using ChatMessageCreateInput directly # because Prisma's TypedDict validation rejects optional fields set to None. @@ -129,14 +125,14 @@ async def add_chat_message( ), PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)), ) - return message + return ChatMessage.from_db(message) async def add_chat_messages_batch( session_id: str, messages: list[dict[str, Any]], start_sequence: int, -) -> list[PrismaChatMessage]: +) -> list[ChatMessage]: """Add multiple messages to a chat session in a batch. Uses a transaction for atomicity - if any message creation fails, @@ -187,21 +183,22 @@ async def add_chat_messages_batch( data={"updatedAt": datetime.now(UTC)}, ) - return created_messages + return [ChatMessage.from_db(m) for m in created_messages] async def get_user_chat_sessions( user_id: str, limit: int = 50, offset: int = 0, -) -> list[PrismaChatSession]: +) -> list[ChatSession]: """Get chat sessions for a user, ordered by most recent.""" - return await PrismaChatSession.prisma().find_many( + prisma_sessions = await PrismaChatSession.prisma().find_many( where={"userId": user_id}, order={"updatedAt": "desc"}, take=limit, skip=offset, ) + return [ChatSession.from_db(s) for s in prisma_sessions] async def get_user_session_count(user_id: str) -> int: diff --git a/autogpt_platform/backend/backend/copilot/executor/manager.py b/autogpt_platform/backend/backend/copilot/executor/manager.py index 0b7235ca13..d3b9d82b39 100644 --- a/autogpt_platform/backend/backend/copilot/executor/manager.py +++ b/autogpt_platform/backend/backend/copilot/executor/manager.py @@ -138,7 +138,7 @@ class CoPilotExecutor(AppProcess): # Refresh cluster locks periodically current_time = time.monotonic() if current_time - last_refresh >= lock_refresh_interval: - for lock in self._task_locks.values(): + for lock in list(self._task_locks.values()): try: lock.refresh() except Exception as e: diff --git a/autogpt_platform/backend/backend/copilot/model.py b/autogpt_platform/backend/backend/copilot/model.py index b48e471a21..f909c8de39 100644 --- a/autogpt_platform/backend/backend/copilot/model.py +++ b/autogpt_platform/backend/backend/copilot/model.py @@ -55,6 +55,19 @@ class ChatMessage(BaseModel): tool_calls: list[dict] | None = None function_call: dict | None = None + @staticmethod + def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage": + """Convert a Prisma ChatMessage to a Pydantic ChatMessage.""" + return ChatMessage( + role=prisma_message.role, + content=prisma_message.content, + name=prisma_message.name, + tool_call_id=prisma_message.toolCallId, + refusal=prisma_message.refusal, + tool_calls=_parse_json_field(prisma_message.toolCalls), + function_call=_parse_json_field(prisma_message.functionCall), + ) + class Usage(BaseModel): prompt_tokens: int @@ -108,26 +121,8 @@ class ChatSession(BaseModel): ) @staticmethod - def from_db( - prisma_session: PrismaChatSession, - prisma_messages: list[PrismaChatMessage] | None = None, - ) -> "ChatSession": - """Convert Prisma models to Pydantic ChatSession.""" - messages = [] - if prisma_messages: - for msg in prisma_messages: - messages.append( - ChatMessage( - role=msg.role, - content=msg.content, - name=msg.name, - tool_call_id=msg.toolCallId, - refusal=msg.refusal, - tool_calls=_parse_json_field(msg.toolCalls), - function_call=_parse_json_field(msg.functionCall), - ) - ) - + def from_db(prisma_session: PrismaChatSession) -> "ChatSession": + """Convert Prisma ChatSession to Pydantic ChatSession.""" # Parse JSON fields from Prisma credentials = _parse_json_field(prisma_session.credentials, default={}) successful_agent_runs = _parse_json_field( @@ -153,7 +148,11 @@ class ChatSession(BaseModel): session_id=prisma_session.id, user_id=prisma_session.userId, title=prisma_session.title, - messages=messages, + messages=( + [ChatMessage.from_db(m) for m in prisma_session.Messages] + if prisma_session.Messages + else [] + ), usage=usage, credentials=credentials, started_at=prisma_session.createdAt, @@ -408,19 +407,18 @@ async def _get_session_from_cache(session_id: str) -> ChatSession | None: async def _get_session_from_db(session_id: str) -> ChatSession | None: """Get a chat session from the database.""" - prisma_session = await chat_db().get_chat_session(session_id) - if not prisma_session: + session = await chat_db().get_chat_session(session_id) + if not session: return None - messages = prisma_session.Messages logger.info( - f"Loading session {session_id} from DB: " - f"has_messages={messages is not None}, " - f"message_count={len(messages) if messages else 0}, " - f"roles={[m.role for m in messages] if messages else []}" + f"Loaded session {session_id} from DB: " + f"has_messages={bool(session.messages)}, " + f"message_count={len(session.messages)}, " + f"roles={[m.role for m in session.messages]}" ) - return ChatSession.from_db(prisma_session, messages) + return session async def upsert_chat_session( @@ -617,14 +615,9 @@ async def get_user_sessions( number of sessions for the user (not just the current page). """ db = chat_db() - prisma_sessions = await db.get_user_chat_sessions(user_id, limit, offset) + sessions = await db.get_user_chat_sessions(user_id, limit, offset) total_count = await db.get_user_session_count(user_id) - sessions = [] - for prisma_session in prisma_sessions: - # Convert without messages for listing (lighter weight) - sessions.append(ChatSession.from_db(prisma_session, None)) - return sessions, total_count