diff --git a/autogpt_platform/backend/backend/api/conftest.py b/autogpt_platform/backend/backend/api/conftest.py index d471a3d536..7e8cc1aec6 100644 --- a/autogpt_platform/backend/backend/api/conftest.py +++ b/autogpt_platform/backend/backend/api/conftest.py @@ -1,4 +1,9 @@ -"""Common test fixtures for server tests.""" +"""Common test fixtures for server tests. + +Note: Common fixtures like test_user_id, admin_user_id, target_user_id, +setup_test_user, and setup_admin_user are defined in the parent conftest.py +(backend/conftest.py) and are available here automatically. +""" import pytest from pytest_snapshot.plugin import Snapshot @@ -11,54 +16,6 @@ def configured_snapshot(snapshot: Snapshot) -> Snapshot: return snapshot -@pytest.fixture -def test_user_id() -> str: - """Test user ID fixture.""" - return "3e53486c-cf57-477e-ba2a-cb02dc828e1a" - - -@pytest.fixture -def admin_user_id() -> str: - """Admin user ID fixture.""" - return "4e53486c-cf57-477e-ba2a-cb02dc828e1b" - - -@pytest.fixture -def target_user_id() -> str: - """Target user ID fixture.""" - return "5e53486c-cf57-477e-ba2a-cb02dc828e1c" - - -@pytest.fixture -async def setup_test_user(test_user_id): - """Create test user in database before tests.""" - from backend.data.user import get_or_create_user - - # Create the test user in the database using JWT token format - user_data = { - "sub": test_user_id, - "email": "test@example.com", - "user_metadata": {"name": "Test User"}, - } - await get_or_create_user(user_data) - return test_user_id - - -@pytest.fixture -async def setup_admin_user(admin_user_id): - """Create admin user in database before tests.""" - from backend.data.user import get_or_create_user - - # Create the admin user in the database using JWT token format - user_data = { - "sub": admin_user_id, - "email": "test-admin@example.com", - "user_metadata": {"name": "Test Admin"}, - } - await get_or_create_user(user_data) - return admin_user_id - - @pytest.fixture def mock_jwt_user(test_user_id): """Provide mock JWT payload for regular user testing.""" diff --git a/autogpt_platform/backend/backend/api/external/v1/tools.py b/autogpt_platform/backend/backend/api/external/v1/tools.py index 07734dd0c9..7b8b327919 100644 --- a/autogpt_platform/backend/backend/api/external/v1/tools.py +++ b/autogpt_platform/backend/backend/api/external/v1/tools.py @@ -15,9 +15,9 @@ from prisma.enums import APIKeyPermission from pydantic import BaseModel, Field from backend.api.external.middleware import require_permission -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools import find_agent_tool, run_agent_tool -from backend.api.features.chat.tools.models import ToolResponseBase +from backend.copilot.model import ChatSession +from backend.copilot.tools import find_agent_tool, run_agent_tool +from backend.copilot.tools.models import ToolResponseBase from backend.data.auth.base import APIAuthorizationInfo logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index d838520c98..2fd7d29319 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -11,14 +11,15 @@ from fastapi import APIRouter, Depends, Header, HTTPException, Query, Response, from fastapi.responses import StreamingResponse from pydantic import BaseModel -from backend.util.exceptions import NotFoundError -from backend.util.feature_flag import Flag, is_feature_enabled - -from . import service as chat_service -from . import stream_registry -from .completion_handler import process_operation_failure, process_operation_success -from .config import ChatConfig -from .model import ( +from backend.copilot import service as chat_service +from backend.copilot import stream_registry +from backend.copilot.completion_handler import ( + process_operation_failure, + process_operation_success, +) +from backend.copilot.config import ChatConfig +from backend.copilot.executor.utils import enqueue_copilot_task +from backend.copilot.model import ( ChatMessage, ChatSession, append_and_save_message, @@ -27,9 +28,8 @@ from .model import ( get_chat_session, get_user_sessions, ) -from .response_model import StreamError, StreamFinish, StreamHeartbeat, StreamStart -from .sdk import service as sdk_service -from .tools.models import ( +from backend.copilot.response_model import StreamError, StreamFinish, StreamHeartbeat +from backend.copilot.tools.models import ( AgentDetailsResponse, AgentOutputResponse, AgentPreviewResponse, @@ -52,7 +52,8 @@ from .tools.models import ( SetupRequirementsResponse, UnderstandingUpdatedResponse, ) -from .tracking import track_user_message +from backend.copilot.tracking import track_user_message +from backend.util.exceptions import NotFoundError config = ChatConfig() @@ -354,7 +355,7 @@ async def stream_chat_post( f"user={user_id}, message_len={len(request.message)}", extra={"json_fields": log_meta}, ) - session = await _validate_and_get_session(session_id, user_id) + await _validate_and_get_session(session_id, user_id) logger.info( f"[TIMING] session validated in {(time.perf_counter() - stream_start_time) * 1000:.1f}ms", extra={ @@ -381,7 +382,7 @@ async def stream_chat_post( message_length=len(request.message), ) logger.info(f"[STREAM] Saving user message to session {session_id}") - session = await append_and_save_message(session_id, message) + await append_and_save_message(session_id, message) logger.info(f"[STREAM] User message saved for session {session_id}") # Create a task in the stream registry for reconnection support @@ -408,125 +409,19 @@ async def stream_chat_post( }, ) - # Background task that runs the AI generation independently of SSE connection - async def run_ai_generation(): - import time as time_module + await enqueue_copilot_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + operation_id=operation_id, + message=request.message, + is_user_message=request.is_user_message, + context=request.context, + ) - gen_start_time = time_module.perf_counter() - logger.info( - f"[TIMING] run_ai_generation STARTED, task={task_id}, session={session_id}, user={user_id}", - extra={"json_fields": log_meta}, - ) - first_chunk_time, ttfc = None, None - chunk_count = 0 - try: - # Emit a start event with task_id for reconnection - start_chunk = StreamStart(messageId=task_id, taskId=task_id) - await stream_registry.publish_chunk(task_id, start_chunk) - logger.info( - f"[TIMING] StreamStart published at {(time_module.perf_counter() - gen_start_time) * 1000:.1f}ms", - extra={ - "json_fields": { - **log_meta, - "elapsed_ms": (time_module.perf_counter() - gen_start_time) - * 1000, - } - }, - ) - - # Choose service based on LaunchDarkly flag (falls back to config default) - use_sdk = await is_feature_enabled( - Flag.COPILOT_SDK, - user_id or "anonymous", - default=config.use_claude_agent_sdk, - ) - stream_fn = ( - sdk_service.stream_chat_completion_sdk - if use_sdk - else chat_service.stream_chat_completion - ) - logger.info( - f"[TIMING] Calling {'sdk' if use_sdk else 'standard'} stream_chat_completion", - extra={"json_fields": log_meta}, - ) - # Pass message=None since we already added it to the session above - async for chunk in stream_fn( - session_id, - None, # Message already in session - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass session with message already added - context=request.context, - ): - # Skip duplicate StreamStart — we already published one above - if isinstance(chunk, StreamStart): - continue - chunk_count += 1 - if first_chunk_time is None: - first_chunk_time = time_module.perf_counter() - ttfc = first_chunk_time - gen_start_time - logger.info( - f"[TIMING] FIRST AI CHUNK at {ttfc:.2f}s, type={type(chunk).__name__}", - extra={ - "json_fields": { - **log_meta, - "chunk_type": type(chunk).__name__, - "time_to_first_chunk_ms": ttfc * 1000, - } - }, - ) - # Write to Redis (subscribers will receive via XREAD) - await stream_registry.publish_chunk(task_id, chunk) - - gen_end_time = time_module.perf_counter() - total_time = (gen_end_time - gen_start_time) * 1000 - logger.info( - f"[TIMING] run_ai_generation FINISHED in {total_time / 1000:.1f}s; " - f"task={task_id}, session={session_id}, " - f"ttfc={ttfc or -1:.2f}s, n_chunks={chunk_count}", - extra={ - "json_fields": { - **log_meta, - "total_time_ms": total_time, - "time_to_first_chunk_ms": ( - ttfc * 1000 if ttfc is not None else None - ), - "n_chunks": chunk_count, - } - }, - ) - await stream_registry.mark_task_completed(task_id, "completed") - except Exception as e: - elapsed = time_module.perf_counter() - gen_start_time - logger.error( - f"[TIMING] run_ai_generation ERROR after {elapsed:.2f}s: {e}", - extra={ - "json_fields": { - **log_meta, - "elapsed_ms": elapsed * 1000, - "error": str(e), - } - }, - ) - # Publish a StreamError so the frontend can display an error message - try: - await stream_registry.publish_chunk( - task_id, - StreamError( - errorText="An error occurred. Please try again.", - code="stream_error", - ), - ) - except Exception: - pass # Best-effort; mark_task_completed will publish StreamFinish - await stream_registry.mark_task_completed(task_id, "failed") - - # Start the AI generation in a background task - bg_task = asyncio.create_task(run_ai_generation()) - await stream_registry.set_task_asyncio_task(task_id, bg_task) setup_time = (time.perf_counter() - stream_start_time) * 1000 logger.info( - f"[TIMING] Background task started, setup={setup_time:.1f}ms", + f"[TIMING] Task enqueued to RabbitMQ, setup={setup_time:.1f}ms", extra={"json_fields": {**log_meta, "setup_time_ms": setup_time}}, ) diff --git a/autogpt_platform/backend/backend/api/features/workspace/routes.py b/autogpt_platform/backend/backend/api/features/workspace/routes.py index b6d0c84572..974465b2c0 100644 --- a/autogpt_platform/backend/backend/api/features/workspace/routes.py +++ b/autogpt_platform/backend/backend/api/features/workspace/routes.py @@ -11,7 +11,7 @@ import fastapi from autogpt_libs.auth.dependencies import get_user_id, requires_user from fastapi.responses import Response -from backend.data.workspace import get_workspace, get_workspace_file +from backend.data.workspace import WorkspaceFile, get_workspace, get_workspace_file from backend.util.workspace_storage import get_workspace_storage @@ -44,11 +44,11 @@ router = fastapi.APIRouter( ) -def _create_streaming_response(content: bytes, file) -> Response: +def _create_streaming_response(content: bytes, file: WorkspaceFile) -> Response: """Create a streaming response for file content.""" return Response( content=content, - media_type=file.mimeType, + media_type=file.mime_type, headers={ "Content-Disposition": _sanitize_filename_for_header(file.name), "Content-Length": str(len(content)), @@ -56,7 +56,7 @@ def _create_streaming_response(content: bytes, file) -> Response: ) -async def _create_file_download_response(file) -> Response: +async def _create_file_download_response(file: WorkspaceFile) -> Response: """ Create a download response for a workspace file. @@ -66,33 +66,33 @@ async def _create_file_download_response(file) -> Response: storage = await get_workspace_storage() # For local storage, stream the file directly - if file.storagePath.startswith("local://"): - content = await storage.retrieve(file.storagePath) + if file.storage_path.startswith("local://"): + content = await storage.retrieve(file.storage_path) return _create_streaming_response(content, file) # For GCS, try to redirect to signed URL, fall back to streaming try: - url = await storage.get_download_url(file.storagePath, expires_in=300) + url = await storage.get_download_url(file.storage_path, expires_in=300) # If we got back an API path (fallback), stream directly instead if url.startswith("/api/"): - content = await storage.retrieve(file.storagePath) + content = await storage.retrieve(file.storage_path) return _create_streaming_response(content, file) return fastapi.responses.RedirectResponse(url=url, status_code=302) except Exception as e: # Log the signed URL failure with context logger.error( f"Failed to get signed URL for file {file.id} " - f"(storagePath={file.storagePath}): {e}", + f"(storagePath={file.storage_path}): {e}", exc_info=True, ) # Fall back to streaming directly from GCS try: - content = await storage.retrieve(file.storagePath) + content = await storage.retrieve(file.storage_path) return _create_streaming_response(content, file) except Exception as fallback_error: logger.error( f"Fallback streaming also failed for file {file.id} " - f"(storagePath={file.storagePath}): {fallback_error}", + f"(storagePath={file.storage_path}): {fallback_error}", exc_info=True, ) raise diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index aed348755b..f37f28dd7c 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -41,11 +41,11 @@ import backend.data.user import backend.integrations.webhooks.utils import backend.util.service import backend.util.settings -from backend.api.features.chat.completion_consumer import ( +from backend.blocks.llm import DEFAULT_LLM_MODEL +from backend.copilot.completion_consumer import ( start_completion_consumer, stop_completion_consumer, ) -from backend.blocks.llm import DEFAULT_LLM_MODEL from backend.data.model import Credentials from backend.integrations.providers import ProviderName from backend.monitoring.instrumentation import instrument_fastapi diff --git a/autogpt_platform/backend/backend/app.py b/autogpt_platform/backend/backend/app.py index 0afed130ed..90a218d2e5 100644 --- a/autogpt_platform/backend/backend/app.py +++ b/autogpt_platform/backend/backend/app.py @@ -38,7 +38,9 @@ def main(**kwargs): from backend.api.rest_api import AgentServer from backend.api.ws_api import WebsocketServer - from backend.executor import DatabaseManager, ExecutionManager, Scheduler + from backend.copilot.executor.manager import CoPilotExecutor + from backend.data.db_manager import DatabaseManager + from backend.executor import ExecutionManager, Scheduler from backend.notifications import NotificationManager run_processes( @@ -48,6 +50,7 @@ def main(**kwargs): WebsocketServer(), AgentServer(), ExecutionManager(), + CoPilotExecutor(), **kwargs, ) diff --git a/autogpt_platform/backend/backend/conftest.py b/autogpt_platform/backend/backend/conftest.py index 57481e4b85..4fc6693f5e 100644 --- a/autogpt_platform/backend/backend/conftest.py +++ b/autogpt_platform/backend/backend/conftest.py @@ -1,6 +1,7 @@ import logging import os +import pytest import pytest_asyncio from dotenv import load_dotenv @@ -27,6 +28,54 @@ async def server(): yield server +@pytest.fixture +def test_user_id() -> str: + """Test user ID fixture.""" + return "3e53486c-cf57-477e-ba2a-cb02dc828e1a" + + +@pytest.fixture +def admin_user_id() -> str: + """Admin user ID fixture.""" + return "4e53486c-cf57-477e-ba2a-cb02dc828e1b" + + +@pytest.fixture +def target_user_id() -> str: + """Target user ID fixture.""" + return "5e53486c-cf57-477e-ba2a-cb02dc828e1c" + + +@pytest.fixture +async def setup_test_user(test_user_id): + """Create test user in database before tests.""" + from backend.data.user import get_or_create_user + + # Create the test user in the database using JWT token format + user_data = { + "sub": test_user_id, + "email": "test@example.com", + "user_metadata": {"name": "Test User"}, + } + await get_or_create_user(user_data) + return test_user_id + + +@pytest.fixture +async def setup_admin_user(admin_user_id): + """Create admin user in database before tests.""" + from backend.data.user import get_or_create_user + + # Create the admin user in the database using JWT token format + user_data = { + "sub": admin_user_id, + "email": "test-admin@example.com", + "user_metadata": {"name": "Test Admin"}, + } + await get_or_create_user(user_data) + return admin_user_id + + @pytest_asyncio.fixture(scope="session", loop_scope="session", autouse=True) async def graph_cleanup(server): created_graph_ids = [] diff --git a/autogpt_platform/backend/backend/copilot/__init__.py b/autogpt_platform/backend/backend/copilot/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/__init__.py @@ -0,0 +1 @@ + diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/copilot/completion_consumer.py similarity index 91% rename from autogpt_platform/backend/backend/api/features/chat/completion_consumer.py rename to autogpt_platform/backend/backend/copilot/completion_consumer.py index f447d46bd7..6627241f35 100644 --- a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py +++ b/autogpt_platform/backend/backend/copilot/completion_consumer.py @@ -37,12 +37,10 @@ stale pending messages from dead consumers. import asyncio import logging -import os import uuid from typing import Any import orjson -from prisma import Prisma from pydantic import BaseModel from redis.exceptions import ResponseError @@ -69,8 +67,8 @@ class OperationCompleteMessage(BaseModel): class ChatCompletionConsumer: """Consumer for chat operation completion messages from Redis Streams. - This consumer initializes its own Prisma client in start() to ensure - database operations work correctly within this async context. + Database operations are handled through the chat_db() accessor, which + routes through DatabaseManager RPC when Prisma is not directly connected. Uses Redis consumer groups to allow multiple platform pods to consume messages reliably with automatic redelivery on failure. @@ -79,7 +77,6 @@ class ChatCompletionConsumer: def __init__(self): self._consumer_task: asyncio.Task | None = None self._running = False - self._prisma: Prisma | None = None self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" async def start(self) -> None: @@ -115,15 +112,6 @@ class ChatCompletionConsumer: f"Chat completion consumer started (consumer: {self._consumer_name})" ) - async def _ensure_prisma(self) -> Prisma: - """Lazily initialize Prisma client on first use.""" - if self._prisma is None: - database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432") - self._prisma = Prisma(datasource={"url": database_url}) - await self._prisma.connect() - logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)") - return self._prisma - async def stop(self) -> None: """Stop the completion consumer.""" self._running = False @@ -136,11 +124,6 @@ class ChatCompletionConsumer: pass self._consumer_task = None - if self._prisma: - await self._prisma.disconnect() - self._prisma = None - logger.info("[COMPLETION] Consumer Prisma client disconnected") - logger.info("Chat completion consumer stopped") async def _consume_messages(self) -> None: @@ -252,7 +235,7 @@ class ChatCompletionConsumer: # XAUTOCLAIM after min_idle_time expires async def _handle_message(self, body: bytes) -> None: - """Handle a completion message using our own Prisma client.""" + """Handle a completion message.""" try: data = orjson.loads(body) message = OperationCompleteMessage(**data) @@ -302,8 +285,7 @@ class ChatCompletionConsumer: message: OperationCompleteMessage, ) -> None: """Handle successful operation completion.""" - prisma = await self._ensure_prisma() - await process_operation_success(task, message.result, prisma) + await process_operation_success(task, message.result) async def _handle_failure( self, @@ -311,8 +293,7 @@ class ChatCompletionConsumer: message: OperationCompleteMessage, ) -> None: """Handle failed operation completion.""" - prisma = await self._ensure_prisma() - await process_operation_failure(task, message.error, prisma) + await process_operation_failure(task, message.error) # Module-level consumer instance diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_handler.py b/autogpt_platform/backend/backend/copilot/completion_handler.py similarity index 82% rename from autogpt_platform/backend/backend/api/features/chat/completion_handler.py rename to autogpt_platform/backend/backend/copilot/completion_handler.py index 905fa2ddba..fd971c5cc2 100644 --- a/autogpt_platform/backend/backend/api/features/chat/completion_handler.py +++ b/autogpt_platform/backend/backend/copilot/completion_handler.py @@ -9,7 +9,8 @@ import logging from typing import Any import orjson -from prisma import Prisma + +from backend.data.db_accessors import chat_db from . import service as chat_service from . import stream_registry @@ -72,48 +73,40 @@ async def _update_tool_message( session_id: str, tool_call_id: str, content: str, - prisma_client: Prisma | None, ) -> None: - """Update tool message in database. + """Update tool message in database using the chat_db accessor. + + Routes through DatabaseManager RPC when Prisma is not directly + connected (e.g. in the CoPilot Executor microservice). Args: session_id: The session ID tool_call_id: The tool call ID to update content: The new content for the message - prisma_client: Optional Prisma client. If None, uses chat_service. Raises: - ToolMessageUpdateError: If the database update fails. The caller should - handle this to avoid marking the task as completed with inconsistent state. + ToolMessageUpdateError: If the database update fails. """ try: - if prisma_client: - # Use provided Prisma client (for consumer with its own connection) - updated_count = await prisma_client.chatmessage.update_many( - where={ - "sessionId": session_id, - "toolCallId": tool_call_id, - }, - data={"content": content}, - ) - # Check if any rows were updated - 0 means message not found - if updated_count == 0: - raise ToolMessageUpdateError( - f"No message found with tool_call_id={tool_call_id} in session {session_id}" - ) - else: - # Use service function (for webhook endpoint) - await chat_service._update_pending_operation( - session_id=session_id, - tool_call_id=tool_call_id, - result=content, + updated = await chat_db().update_tool_message_content( + session_id=session_id, + tool_call_id=tool_call_id, + new_content=content, + ) + if not updated: + raise ToolMessageUpdateError( + f"No message found with tool_call_id=" + f"{tool_call_id} in session {session_id}" ) except ToolMessageUpdateError: raise except Exception as e: - logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True) + logger.error( + f"[COMPLETION] Failed to update tool message: {e}", + exc_info=True, + ) raise ToolMessageUpdateError( - f"Failed to update tool message for tool_call_id={tool_call_id}: {e}" + f"Failed to update tool message for tool call #{tool_call_id}: {e}" ) from e @@ -202,7 +195,6 @@ async def _save_agent_from_result( async def process_operation_success( task: stream_registry.ActiveTask, result: dict | str | None, - prisma_client: Prisma | None = None, ) -> None: """Handle successful operation completion. @@ -212,12 +204,10 @@ async def process_operation_success( Args: task: The active task that completed result: The result data from the operation - prisma_client: Optional Prisma client for database operations. - If None, uses chat_service._update_pending_operation instead. Raises: - ToolMessageUpdateError: If the database update fails. The task will be - marked as failed instead of completed to avoid inconsistent state. + ToolMessageUpdateError: If the database update fails. The task + will be marked as failed instead of completed. """ # For agent generation tools, save the agent to library if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict): @@ -250,7 +240,6 @@ async def process_operation_success( session_id=task.session_id, tool_call_id=task.tool_call_id, content=result_str, - prisma_client=prisma_client, ) except ToolMessageUpdateError: # DB update failed - mark task as failed to avoid inconsistent state @@ -293,18 +282,15 @@ async def process_operation_success( async def process_operation_failure( task: stream_registry.ActiveTask, error: str | None, - prisma_client: Prisma | None = None, ) -> None: """Handle failed operation completion. - Publishes the error to the stream registry, updates the database with - the error response, and marks the task as failed. + Publishes the error to the stream registry, updates the database + with the error response, and marks the task as failed. Args: task: The active task that failed error: The error message from the operation - prisma_client: Optional Prisma client for database operations. - If None, uses chat_service._update_pending_operation instead. """ error_msg = error or "Operation failed" @@ -325,7 +311,6 @@ async def process_operation_failure( session_id=task.session_id, tool_call_id=task.tool_call_id, content=error_response.model_dump_json(), - prisma_client=prisma_client, ) except ToolMessageUpdateError: # DB update failed - log but continue with cleanup diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/copilot/config.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/config.py rename to autogpt_platform/backend/backend/copilot/config.py diff --git a/autogpt_platform/backend/backend/api/features/chat/db.py b/autogpt_platform/backend/backend/copilot/db.py similarity index 89% rename from autogpt_platform/backend/backend/api/features/chat/db.py rename to autogpt_platform/backend/backend/copilot/db.py index 303ea0a698..cbf0852500 100644 --- a/autogpt_platform/backend/backend/api/features/chat/db.py +++ b/autogpt_platform/backend/backend/copilot/db.py @@ -14,29 +14,27 @@ from prisma.types import ( ChatSessionWhereInput, ) -from backend.data.db import transaction +from backend.data import db from backend.util.json import SafeJson +from .model import ChatMessage, ChatSession, ChatSessionInfo + logger = logging.getLogger(__name__) -async def get_chat_session(session_id: str) -> PrismaChatSession | None: +async def get_chat_session(session_id: str) -> ChatSession | None: """Get a chat session by ID from the database.""" session = await PrismaChatSession.prisma().find_unique( where={"id": session_id}, - include={"Messages": True}, + include={"Messages": {"order_by": {"sequence": "asc"}}}, ) - if session and session.Messages: - # Sort messages by sequence in Python - Prisma Python client doesn't support - # order_by in include clauses (unlike Prisma JS), so we sort after fetching - session.Messages.sort(key=lambda m: m.sequence) - return session + return ChatSession.from_db(session) if session else None async def create_chat_session( session_id: str, user_id: str, -) -> PrismaChatSession: +) -> ChatSessionInfo: """Create a new chat session in the database.""" data = ChatSessionCreateInput( id=session_id, @@ -45,7 +43,8 @@ async def create_chat_session( successfulAgentRuns=SafeJson({}), successfulAgentSchedules=SafeJson({}), ) - return await PrismaChatSession.prisma().create(data=data) + prisma_session = await PrismaChatSession.prisma().create(data=data) + return ChatSessionInfo.from_db(prisma_session) async def update_chat_session( @@ -56,7 +55,7 @@ async def update_chat_session( total_prompt_tokens: int | None = None, total_completion_tokens: int | None = None, title: str | None = None, -) -> PrismaChatSession | None: +) -> ChatSession | None: """Update a chat session's metadata.""" data: ChatSessionUpdateInput = {"updatedAt": datetime.now(UTC)} @@ -76,12 +75,9 @@ async def update_chat_session( session = await PrismaChatSession.prisma().update( where={"id": session_id}, data=data, - include={"Messages": True}, + include={"Messages": {"order_by": {"sequence": "asc"}}}, ) - if session and session.Messages: - # Sort in Python - Prisma Python doesn't support order_by in include clauses - session.Messages.sort(key=lambda m: m.sequence) - return session + return ChatSession.from_db(session) if session else None async def add_chat_message( @@ -94,7 +90,7 @@ async def add_chat_message( refusal: str | None = None, tool_calls: list[dict[str, Any]] | None = None, function_call: dict[str, Any] | None = None, -) -> PrismaChatMessage: +) -> ChatMessage: """Add a message to a chat session.""" # Build input dict dynamically rather than using ChatMessageCreateInput directly # because Prisma's TypedDict validation rejects optional fields set to None. @@ -129,14 +125,14 @@ async def add_chat_message( ), PrismaChatMessage.prisma().create(data=cast(ChatMessageCreateInput, data)), ) - return message + return ChatMessage.from_db(message) async def add_chat_messages_batch( session_id: str, messages: list[dict[str, Any]], start_sequence: int, -) -> list[PrismaChatMessage]: +) -> list[ChatMessage]: """Add multiple messages to a chat session in a batch. Uses a transaction for atomicity - if any message creation fails, @@ -147,7 +143,7 @@ async def add_chat_messages_batch( created_messages = [] - async with transaction() as tx: + async with db.transaction() as tx: for i, msg in enumerate(messages): # Build input dict dynamically rather than using ChatMessageCreateInput # directly because Prisma's TypedDict validation rejects optional fields @@ -187,21 +183,22 @@ async def add_chat_messages_batch( data={"updatedAt": datetime.now(UTC)}, ) - return created_messages + return [ChatMessage.from_db(m) for m in created_messages] async def get_user_chat_sessions( user_id: str, limit: int = 50, offset: int = 0, -) -> list[PrismaChatSession]: +) -> list[ChatSessionInfo]: """Get chat sessions for a user, ordered by most recent.""" - return await PrismaChatSession.prisma().find_many( + prisma_sessions = await PrismaChatSession.prisma().find_many( where={"userId": user_id}, order={"updatedAt": "desc"}, take=limit, skip=offset, ) + return [ChatSessionInfo.from_db(s) for s in prisma_sessions] async def get_user_session_count(user_id: str) -> int: diff --git a/autogpt_platform/backend/test/chat/__init__.py b/autogpt_platform/backend/backend/copilot/executor/__init__.py similarity index 100% rename from autogpt_platform/backend/test/chat/__init__.py rename to autogpt_platform/backend/backend/copilot/executor/__init__.py diff --git a/autogpt_platform/backend/backend/copilot/executor/__main__.py b/autogpt_platform/backend/backend/copilot/executor/__main__.py new file mode 100644 index 0000000000..00d42d6d95 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/executor/__main__.py @@ -0,0 +1,18 @@ +"""Entry point for running the CoPilot Executor service. + +Usage: + python -m backend.copilot.executor +""" + +from backend.app import run_processes + +from .manager import CoPilotExecutor + + +def main(): + """Run the CoPilot Executor service.""" + run_processes(CoPilotExecutor()) + + +if __name__ == "__main__": + main() diff --git a/autogpt_platform/backend/backend/copilot/executor/manager.py b/autogpt_platform/backend/backend/copilot/executor/manager.py new file mode 100644 index 0000000000..212634d342 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/executor/manager.py @@ -0,0 +1,519 @@ +"""CoPilot Executor Manager - main service for CoPilot task execution. + +This module contains the CoPilotExecutor class that consumes chat tasks from +RabbitMQ and processes them using a thread pool, following the graph executor pattern. +""" + +import asyncio +import logging +import os +import threading +import time +import uuid +from concurrent.futures import Future, ThreadPoolExecutor + +from pika.adapters.blocking_connection import BlockingChannel +from pika.exceptions import AMQPChannelError, AMQPConnectionError +from pika.spec import Basic, BasicProperties +from prometheus_client import Gauge, start_http_server + +from backend.data import redis_client as redis +from backend.data.rabbitmq import SyncRabbitMQ +from backend.executor.cluster_lock import ClusterLock +from backend.util.decorator import error_logged +from backend.util.logging import TruncatedLogger +from backend.util.process import AppProcess +from backend.util.retry import continuous_retry +from backend.util.settings import Settings + +from .processor import execute_copilot_task, init_worker +from .utils import ( + COPILOT_CANCEL_QUEUE_NAME, + COPILOT_EXECUTION_QUEUE_NAME, + GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS, + CancelCoPilotEvent, + CoPilotExecutionEntry, + create_copilot_queue_config, +) + +logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]") +settings = Settings() + +# Prometheus metrics +active_tasks_gauge = Gauge( + "copilot_executor_active_tasks", + "Number of active CoPilot tasks", +) +pool_size_gauge = Gauge( + "copilot_executor_pool_size", + "Maximum number of CoPilot executor workers", +) +utilization_gauge = Gauge( + "copilot_executor_utilization_ratio", + "Ratio of active tasks to pool size", +) + + +class CoPilotExecutor(AppProcess): + """CoPilot Executor service for processing chat generation tasks. + + This service consumes tasks from RabbitMQ, processes them using a thread pool, + and publishes results to Redis Streams. It follows the graph executor pattern + for reliable message handling and graceful shutdown. + + Key features: + - RabbitMQ-based task distribution with manual acknowledgment + - Thread pool executor for concurrent task processing + - Cluster lock for duplicate prevention across pods + - Graceful shutdown with timeout for in-flight tasks + - FANOUT exchange for cancellation broadcast + """ + + def __init__(self): + super().__init__() + self.pool_size = settings.config.num_copilot_workers + self.active_tasks: dict[str, tuple[Future, threading.Event]] = {} + self.executor_id = str(uuid.uuid4()) + + self._executor = None + self._stop_consuming = None + + self._cancel_thread = None + self._cancel_client = None + self._run_thread = None + self._run_client = None + + self._task_locks: dict[str, ClusterLock] = {} + self._active_tasks_lock = threading.Lock() + + # ============ Main Entry Points (AppProcess interface) ============ # + + def run(self): + """Main service loop - consume from RabbitMQ.""" + logger.info(f"Pod assigned executor_id: {self.executor_id}") + logger.info(f"Spawn max-{self.pool_size} workers...") + + pool_size_gauge.set(self.pool_size) + self._update_metrics() + start_http_server(settings.config.copilot_executor_port) + + self.cancel_thread.start() + self.run_thread.start() + + while True: + time.sleep(1e5) + + def cleanup(self): + """Graceful shutdown with active execution waiting.""" + pid = os.getpid() + logger.info(f"[cleanup {pid}] Starting graceful shutdown...") + + # Signal the consumer thread to stop + try: + self.stop_consuming.set() + run_channel = self.run_client.get_channel() + run_channel.connection.add_callback_threadsafe( + lambda: run_channel.stop_consuming() + ) + logger.info(f"[cleanup {pid}] Consumer has been signaled to stop") + except Exception as e: + logger.error(f"[cleanup {pid}] Error stopping consumer: {e}") + + # Wait for active executions to complete + if self.active_tasks: + logger.info( + f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..." + ) + + start_time = time.monotonic() + last_refresh = start_time + lock_refresh_interval = settings.config.cluster_lock_timeout / 10 + + while ( + self.active_tasks + and (time.monotonic() - start_time) < GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS + ): + self._cleanup_completed_tasks() + if not self.active_tasks: + break + + # Refresh cluster locks periodically + current_time = time.monotonic() + if current_time - last_refresh >= lock_refresh_interval: + for lock in list(self._task_locks.values()): + try: + lock.refresh() + except Exception as e: + logger.warning( + f"[cleanup {pid}] Failed to refresh lock: {e}" + ) + last_refresh = current_time + + logger.info( + f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..." + ) + time.sleep(10.0) + + # Stop message consumers + if self._run_thread: + self._stop_message_consumers( + self._run_thread, self.run_client, "[cleanup][run]" + ) + if self._cancel_thread: + self._stop_message_consumers( + self._cancel_thread, self.cancel_client, "[cleanup][cancel]" + ) + + # Shutdown executor + if self._executor: + logger.info(f"[cleanup {pid}] Shutting down executor...") + self._executor.shutdown(wait=False) + + # Close async resources (workspace storage aiohttp session, etc.) + try: + from backend.util.workspace_storage import shutdown_workspace_storage + + loop = asyncio.new_event_loop() + loop.run_until_complete(shutdown_workspace_storage()) + loop.close() + except Exception as e: + logger.warning(f"[cleanup {pid}] Error closing workspace storage: {e}") + + # Release any remaining locks + for task_id, lock in list(self._task_locks.items()): + try: + lock.release() + logger.info(f"[cleanup {pid}] Released lock for {task_id}") + except Exception as e: + logger.error( + f"[cleanup {pid}] Failed to release lock for {task_id}: {e}" + ) + + logger.info(f"[cleanup {pid}] Graceful shutdown completed") + + # ============ RabbitMQ Consumer Methods ============ # + + @continuous_retry() + def _consume_cancel(self): + """Consume cancellation messages from FANOUT exchange.""" + if self.stop_consuming.is_set() and not self.active_tasks: + logger.info("Stop reconnecting cancel consumer - service cleaned up") + return + + if not self.cancel_client.is_ready: + self.cancel_client.disconnect() + self.cancel_client.connect() + + # Check again after connect - shutdown may have been requested + if self.stop_consuming.is_set() and not self.active_tasks: + logger.info("Stop consuming requested during reconnect - disconnecting") + self.cancel_client.disconnect() + return + + cancel_channel = self.cancel_client.get_channel() + cancel_channel.basic_consume( + queue=COPILOT_CANCEL_QUEUE_NAME, + on_message_callback=self._handle_cancel_message, + auto_ack=True, + ) + logger.info("Starting to consume cancel messages...") + cancel_channel.start_consuming() + if not self.stop_consuming.is_set() or self.active_tasks: + raise RuntimeError("Cancel message consumer stopped unexpectedly") + logger.info("Cancel message consumer stopped gracefully") + + @continuous_retry() + def _consume_run(self): + """Consume run messages from DIRECT exchange.""" + if self.stop_consuming.is_set(): + logger.info("Stop reconnecting run consumer - service cleaned up") + return + + if not self.run_client.is_ready: + self.run_client.disconnect() + self.run_client.connect() + + # Check again after connect - shutdown may have been requested + if self.stop_consuming.is_set(): + logger.info("Stop consuming requested during reconnect - disconnecting") + self.run_client.disconnect() + return + + run_channel = self.run_client.get_channel() + run_channel.basic_qos(prefetch_count=self.pool_size) + + run_channel.basic_consume( + queue=COPILOT_EXECUTION_QUEUE_NAME, + on_message_callback=self._handle_run_message, + auto_ack=False, + consumer_tag="copilot_execution_consumer", + ) + logger.info("Starting to consume run messages...") + run_channel.start_consuming() + if not self.stop_consuming.is_set(): + raise RuntimeError("Run message consumer stopped unexpectedly") + logger.info("Run message consumer stopped gracefully") + + # ============ Message Handlers ============ # + + @error_logged(swallow=True) + def _handle_cancel_message( + self, + _channel: BlockingChannel, + _method: Basic.Deliver, + _properties: BasicProperties, + body: bytes, + ): + """Handle cancel message from FANOUT exchange.""" + request = CancelCoPilotEvent.model_validate_json(body) + task_id = request.task_id + if not task_id: + logger.warning("Cancel message missing 'task_id'") + return + if task_id not in self.active_tasks: + logger.debug(f"Cancel received for {task_id} but not active") + return + + _, cancel_event = self.active_tasks[task_id] + logger.info(f"Received cancel for {task_id}") + if not cancel_event.is_set(): + cancel_event.set() + else: + logger.debug(f"Cancel already set for {task_id}") + + def _handle_run_message( + self, + _channel: BlockingChannel, + method: Basic.Deliver, + _properties: BasicProperties, + body: bytes, + ): + """Handle run message from DIRECT exchange.""" + delivery_tag = method.delivery_tag + # Capture the channel used at message delivery time to ensure we ack + # on the correct channel. Delivery tags are channel-scoped and become + # invalid if the channel is recreated after reconnection. + delivery_channel = _channel + + def ack_message(reject: bool, requeue: bool): + """Acknowledge or reject the message. + + Uses the channel from the original message delivery. If the channel + is no longer open (e.g., after reconnection), logs a warning and + skips the ack - RabbitMQ will redeliver the message automatically. + """ + try: + if not delivery_channel.is_open: + logger.warning( + f"Channel closed, cannot ack delivery_tag={delivery_tag}. " + "Message will be redelivered by RabbitMQ." + ) + return + + if reject: + delivery_channel.connection.add_callback_threadsafe( + lambda: delivery_channel.basic_nack( + delivery_tag, requeue=requeue + ) + ) + else: + delivery_channel.connection.add_callback_threadsafe( + lambda: delivery_channel.basic_ack(delivery_tag) + ) + except (AMQPChannelError, AMQPConnectionError) as e: + # Channel/connection errors indicate stale delivery tag - don't retry + logger.warning( + f"Cannot ack delivery_tag={delivery_tag} due to channel/connection " + f"error: {e}. Message will be redelivered by RabbitMQ." + ) + except Exception as e: + # Other errors might be transient, but log and skip to avoid blocking + logger.error( + f"Unexpected error acking delivery_tag={delivery_tag}: {e}" + ) + + # Check if we're shutting down + if self.stop_consuming.is_set(): + logger.info("Rejecting new task during shutdown") + ack_message(reject=True, requeue=True) + return + + # Check if we can accept more tasks + self._cleanup_completed_tasks() + if len(self.active_tasks) >= self.pool_size: + ack_message(reject=True, requeue=True) + return + + try: + entry = CoPilotExecutionEntry.model_validate_json(body) + except Exception as e: + logger.error(f"Could not parse run message: {e}, body={body}") + ack_message(reject=True, requeue=False) + return + + task_id = entry.task_id + + # Check for local duplicate - task is already running on this executor + if task_id in self.active_tasks: + logger.warning( + f"Task {task_id} already running locally, rejecting duplicate" + ) + ack_message(reject=True, requeue=False) + return + + # Try to acquire cluster-wide lock + cluster_lock = ClusterLock( + redis=redis.get_redis(), + key=f"copilot:task:{task_id}:lock", + owner_id=self.executor_id, + timeout=settings.config.cluster_lock_timeout, + ) + current_owner = cluster_lock.try_acquire() + if current_owner != self.executor_id: + if current_owner is not None: + logger.warning(f"Task {task_id} already running on pod {current_owner}") + ack_message(reject=True, requeue=False) + else: + logger.warning( + f"Could not acquire lock for {task_id} - Redis unavailable" + ) + ack_message(reject=True, requeue=True) + return + + # Execute the task + try: + self._task_locks[task_id] = cluster_lock + + logger.info( + f"Acquired cluster lock for {task_id}, executor_id={self.executor_id}" + ) + + cancel_event = threading.Event() + future = self.executor.submit( + execute_copilot_task, entry, cancel_event, cluster_lock + ) + self.active_tasks[task_id] = (future, cancel_event) + except Exception as e: + logger.warning(f"Failed to setup execution for {task_id}: {e}") + cluster_lock.release() + if task_id in self._task_locks: + del self._task_locks[task_id] + ack_message(reject=True, requeue=True) + return + + self._update_metrics() + + def on_run_done(f: Future): + logger.info(f"Run completed for {task_id}") + try: + if exec_error := f.exception(): + logger.error(f"Execution for {task_id} failed: {exec_error}") + # Don't requeue failed tasks - they've been marked as failed + # in the stream registry. Requeuing would cause infinite retries + # for deterministic failures. + ack_message(reject=True, requeue=False) + else: + ack_message(reject=False, requeue=False) + except BaseException as e: + logger.exception(f"Error in run completion callback: {e}") + finally: + # Release the cluster lock + if task_id in self._task_locks: + logger.info(f"Releasing cluster lock for {task_id}") + self._task_locks[task_id].release() + del self._task_locks[task_id] + self._cleanup_completed_tasks() + + future.add_done_callback(on_run_done) + + # ============ Helper Methods ============ # + + def _cleanup_completed_tasks(self) -> list[str]: + """Remove completed futures from active_tasks and update metrics.""" + completed_tasks = [] + with self._active_tasks_lock: + for task_id, (future, _) in list(self.active_tasks.items()): + if future.done(): + completed_tasks.append(task_id) + self.active_tasks.pop(task_id, None) + logger.info(f"Cleaned up completed task {task_id}") + + self._update_metrics() + return completed_tasks + + def _update_metrics(self): + """Update Prometheus metrics.""" + active_count = len(self.active_tasks) + active_tasks_gauge.set(active_count) + if self.stop_consuming.is_set(): + utilization_gauge.set(1.0) + else: + utilization_gauge.set( + active_count / self.pool_size if self.pool_size > 0 else 0 + ) + + def _stop_message_consumers( + self, thread: threading.Thread, client: SyncRabbitMQ, prefix: str + ): + """Stop a message consumer thread.""" + try: + channel = client.get_channel() + channel.connection.add_callback_threadsafe(lambda: channel.stop_consuming()) + + thread.join(timeout=300) + if thread.is_alive(): + logger.error( + f"{prefix} Thread did not finish in time, forcing disconnect" + ) + + client.disconnect() + logger.info(f"{prefix} Client disconnected") + except Exception as e: + logger.error(f"{prefix} Error disconnecting client: {e}") + + # ============ Lazy-initialized Properties ============ # + + @property + def cancel_thread(self) -> threading.Thread: + if self._cancel_thread is None: + self._cancel_thread = threading.Thread( + target=lambda: self._consume_cancel(), + daemon=True, + ) + return self._cancel_thread + + @property + def run_thread(self) -> threading.Thread: + if self._run_thread is None: + self._run_thread = threading.Thread( + target=lambda: self._consume_run(), + daemon=True, + ) + return self._run_thread + + @property + def stop_consuming(self) -> threading.Event: + if self._stop_consuming is None: + self._stop_consuming = threading.Event() + return self._stop_consuming + + @property + def executor(self) -> ThreadPoolExecutor: + if self._executor is None: + self._executor = ThreadPoolExecutor( + max_workers=self.pool_size, + initializer=init_worker, + ) + return self._executor + + @property + def cancel_client(self) -> SyncRabbitMQ: + if self._cancel_client is None: + self._cancel_client = SyncRabbitMQ(create_copilot_queue_config()) + return self._cancel_client + + @property + def run_client(self) -> SyncRabbitMQ: + if self._run_client is None: + self._run_client = SyncRabbitMQ(create_copilot_queue_config()) + return self._run_client diff --git a/autogpt_platform/backend/backend/copilot/executor/processor.py b/autogpt_platform/backend/backend/copilot/executor/processor.py new file mode 100644 index 0000000000..eb941d5efd --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/executor/processor.py @@ -0,0 +1,253 @@ +"""CoPilot execution processor - per-worker execution logic. + +This module contains the processor class that handles CoPilot task execution +in a thread-local context, following the graph executor pattern. +""" + +import asyncio +import logging +import threading +import time + +from backend.copilot import service as copilot_service +from backend.copilot import stream_registry +from backend.copilot.config import ChatConfig +from backend.copilot.response_model import StreamError, StreamFinish, StreamFinishStep +from backend.copilot.sdk import service as sdk_service +from backend.executor.cluster_lock import ClusterLock +from backend.util.decorator import error_logged +from backend.util.feature_flag import Flag, is_feature_enabled +from backend.util.logging import TruncatedLogger, configure_logging +from backend.util.process import set_service_name +from backend.util.retry import func_retry + +from .utils import CoPilotExecutionEntry, CoPilotLogMetadata + +logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]") + + +# ============ Module Entry Points ============ # + +# Thread-local storage for processor instances +_tls = threading.local() + + +def execute_copilot_task( + entry: CoPilotExecutionEntry, + cancel: threading.Event, + cluster_lock: ClusterLock, +): + """Execute a CoPilot task using the thread-local processor. + + This function is the entry point called by the thread pool executor. + + Args: + entry: The task payload + cancel: Threading event to signal cancellation + cluster_lock: Distributed lock for this execution + """ + processor: CoPilotProcessor = _tls.processor + return processor.execute(entry, cancel, cluster_lock) + + +def init_worker(): + """Initialize the processor for the current worker thread. + + This function is called by the thread pool executor when a new worker + thread is created. It ensures each worker has its own processor instance. + """ + _tls.processor = CoPilotProcessor() + _tls.processor.on_executor_start() + + +# ============ Processor Class ============ # + + +class CoPilotProcessor: + """Per-worker execution logic for CoPilot tasks. + + This class is instantiated once per worker thread and handles the execution + of CoPilot chat generation tasks. It maintains an async event loop for + running the async service code. + + The execution flow: + 1. CoPilot task is picked from RabbitMQ queue + 2. Manager submits task to thread pool + 3. Processor executes the task in its event loop + 4. Results are published to Redis Streams + """ + + @func_retry + def on_executor_start(self): + """Initialize the processor when the worker thread starts. + + This method is called once per worker thread to set up the async event + loop and initialize any required resources. + + Database is accessed only through DatabaseManager, so we don't need to connect + to Prisma directly. + """ + configure_logging() + set_service_name("CoPilotExecutor") + self.tid = threading.get_ident() + self.execution_loop = asyncio.new_event_loop() + self.execution_thread = threading.Thread( + target=self.execution_loop.run_forever, daemon=True + ) + self.execution_thread.start() + + logger.info(f"[CoPilotExecutor] Worker {self.tid} started") + + @error_logged(swallow=False) + def execute( + self, + entry: CoPilotExecutionEntry, + cancel: threading.Event, + cluster_lock: ClusterLock, + ): + """Execute a CoPilot task. + + This is the main entry point for task execution. It runs the async + execution logic in the worker's event loop and handles errors. + + Args: + entry: The task payload containing session and message info + cancel: Threading event to signal cancellation + cluster_lock: Distributed lock to prevent duplicate execution + """ + log = CoPilotLogMetadata( + logging.getLogger(__name__), + task_id=entry.task_id, + session_id=entry.session_id, + user_id=entry.user_id, + ) + log.info("Starting execution") + + start_time = time.monotonic() + + try: + # Run the async execution in our event loop + future = asyncio.run_coroutine_threadsafe( + self._execute_async(entry, cancel, cluster_lock, log), + self.execution_loop, + ) + + # Wait for completion, checking cancel periodically + while not future.done(): + try: + future.result(timeout=1.0) + except asyncio.TimeoutError: + if cancel.is_set(): + log.info("Cancellation requested") + future.cancel() + break + # Refresh cluster lock to maintain ownership + cluster_lock.refresh() + + if not future.cancelled(): + # Get result to propagate any exceptions + future.result() + + elapsed = time.monotonic() - start_time + log.info(f"Execution completed in {elapsed:.2f}s") + + except Exception as e: + elapsed = time.monotonic() - start_time + log.error(f"Execution failed after {elapsed:.2f}s: {e}") + # Note: _execute_async already marks the task as failed before re-raising, + # so we don't call _mark_task_failed here to avoid duplicate error events. + raise + + async def _execute_async( + self, + entry: CoPilotExecutionEntry, + cancel: threading.Event, + cluster_lock: ClusterLock, + log: CoPilotLogMetadata, + ): + """Async execution logic for CoPilot task. + + This method calls the existing stream_chat_completion service function + and publishes results to the stream registry. + + Args: + entry: The task payload + cancel: Threading event to signal cancellation + cluster_lock: Distributed lock for refresh + log: Structured logger for this task + """ + last_refresh = time.monotonic() + refresh_interval = 30.0 # Refresh lock every 30 seconds + + try: + # Choose service based on LaunchDarkly flag + config = ChatConfig() + use_sdk = await is_feature_enabled( + Flag.COPILOT_SDK, + entry.user_id or "anonymous", + default=config.use_claude_agent_sdk, + ) + stream_fn = ( + sdk_service.stream_chat_completion_sdk + if use_sdk + else copilot_service.stream_chat_completion + ) + log.info(f"Using {'SDK' if use_sdk else 'standard'} service") + + # Stream chat completion and publish chunks to Redis + async for chunk in stream_fn( + session_id=entry.session_id, + message=entry.message if entry.message else None, + is_user_message=entry.is_user_message, + user_id=entry.user_id, + context=entry.context, + ): + # Check for cancellation + if cancel.is_set(): + log.info("Cancelled during streaming") + await stream_registry.publish_chunk( + entry.task_id, StreamError(errorText="Operation cancelled") + ) + await stream_registry.publish_chunk( + entry.task_id, StreamFinishStep() + ) + await stream_registry.publish_chunk(entry.task_id, StreamFinish()) + await stream_registry.mark_task_completed( + entry.task_id, status="failed" + ) + return + + # Refresh cluster lock periodically + current_time = time.monotonic() + if current_time - last_refresh >= refresh_interval: + cluster_lock.refresh() + last_refresh = current_time + + # Publish chunk to stream registry + await stream_registry.publish_chunk(entry.task_id, chunk) + + # Mark task as completed + await stream_registry.mark_task_completed(entry.task_id, status="completed") + log.info("Task completed successfully") + + except asyncio.CancelledError: + log.info("Task cancelled") + await stream_registry.mark_task_completed(entry.task_id, status="failed") + raise + + except Exception as e: + log.error(f"Task failed: {e}") + await self._mark_task_failed(entry.task_id, str(e)) + raise + + async def _mark_task_failed(self, task_id: str, error_message: str): + """Mark a task as failed and publish error to stream registry.""" + try: + await stream_registry.publish_chunk( + task_id, StreamError(errorText=error_message) + ) + await stream_registry.publish_chunk(task_id, StreamFinishStep()) + await stream_registry.publish_chunk(task_id, StreamFinish()) + await stream_registry.mark_task_completed(task_id, status="failed") + except Exception as e: + logger.error(f"Failed to mark task {task_id} as failed: {e}") diff --git a/autogpt_platform/backend/backend/copilot/executor/utils.py b/autogpt_platform/backend/backend/copilot/executor/utils.py new file mode 100644 index 0000000000..60d9cb22bf --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/executor/utils.py @@ -0,0 +1,207 @@ +"""RabbitMQ queue configuration for CoPilot executor. + +Defines two exchanges and queues following the graph executor pattern: +- 'copilot_execution' (DIRECT) for chat generation tasks +- 'copilot_cancel' (FANOUT) for cancellation requests +""" + +import logging + +from pydantic import BaseModel + +from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig +from backend.util.logging import TruncatedLogger, is_structured_logging_enabled + +logger = logging.getLogger(__name__) + + +# ============ Logging Helper ============ # + + +class CoPilotLogMetadata(TruncatedLogger): + """Structured logging helper for CoPilot executor. + + In cloud environments (structured logging enabled), uses a simple prefix + and passes metadata via json_fields. In local environments, uses a detailed + prefix with all metadata key-value pairs for easier debugging. + + Args: + logger: The underlying logger instance + max_length: Maximum log message length before truncation + **kwargs: Metadata key-value pairs (e.g., task_id="abc", session_id="xyz") + These are added to json_fields in cloud mode, or to the prefix in local mode. + """ + + def __init__( + self, + logger: logging.Logger, + max_length: int = 1000, + **kwargs: str | None, + ): + # Filter out None values + metadata = {k: v for k, v in kwargs.items() if v is not None} + metadata["component"] = "CoPilotExecutor" + + if is_structured_logging_enabled(): + prefix = "[CoPilotExecutor]" + else: + # Build prefix from metadata key-value pairs + meta_parts = "|".join( + f"{k}:{v}" for k, v in metadata.items() if k != "component" + ) + prefix = ( + f"[CoPilotExecutor|{meta_parts}]" if meta_parts else "[CoPilotExecutor]" + ) + + super().__init__( + logger, + max_length=max_length, + prefix=prefix, + metadata=metadata, + ) + + +# ============ Exchange and Queue Configuration ============ # + +COPILOT_EXECUTION_EXCHANGE = Exchange( + name="copilot_execution", + type=ExchangeType.DIRECT, + durable=True, + auto_delete=False, +) +COPILOT_EXECUTION_QUEUE_NAME = "copilot_execution_queue" +COPILOT_EXECUTION_ROUTING_KEY = "copilot.run" + +COPILOT_CANCEL_EXCHANGE = Exchange( + name="copilot_cancel", + type=ExchangeType.FANOUT, + durable=True, + auto_delete=False, +) +COPILOT_CANCEL_QUEUE_NAME = "copilot_cancel_queue" + +# CoPilot operations can include extended thinking and agent generation +# which may take 30+ minutes to complete +COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour + +# Graceful shutdown timeout - allow in-flight operations to complete +GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes + + +def create_copilot_queue_config() -> RabbitMQConfig: + """Create RabbitMQ configuration for CoPilot executor. + + Defines two exchanges and queues: + - 'copilot_execution' (DIRECT) for chat generation tasks + - 'copilot_cancel' (FANOUT) for cancellation requests + + Returns: + RabbitMQConfig with exchanges and queues defined + """ + run_queue = Queue( + name=COPILOT_EXECUTION_QUEUE_NAME, + exchange=COPILOT_EXECUTION_EXCHANGE, + routing_key=COPILOT_EXECUTION_ROUTING_KEY, + durable=True, + auto_delete=False, + arguments={ + # Extended consumer timeout for long-running LLM operations + # Default 30-minute timeout is insufficient for extended thinking + # and agent generation which can take 30+ minutes + "x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS + * 1000, + }, + ) + cancel_queue = Queue( + name=COPILOT_CANCEL_QUEUE_NAME, + exchange=COPILOT_CANCEL_EXCHANGE, + routing_key="", # not used for FANOUT + durable=True, + auto_delete=False, + ) + return RabbitMQConfig( + vhost="/", + exchanges=[COPILOT_EXECUTION_EXCHANGE, COPILOT_CANCEL_EXCHANGE], + queues=[run_queue, cancel_queue], + ) + + +# ============ Message Models ============ # + + +class CoPilotExecutionEntry(BaseModel): + """Task payload for CoPilot AI generation. + + This model represents a chat generation task to be processed by the executor. + """ + + task_id: str + """Unique identifier for this task (used for stream registry)""" + + session_id: str + """Chat session ID""" + + user_id: str | None + """User ID (may be None for anonymous users)""" + + operation_id: str + """Operation ID for webhook callbacks and completion tracking""" + + message: str + """User's message to process""" + + is_user_message: bool = True + """Whether the message is from the user (vs system/assistant)""" + + context: dict[str, str] | None = None + """Optional context for the message (e.g., {url: str, content: str})""" + + +class CancelCoPilotEvent(BaseModel): + """Event to cancel a CoPilot operation.""" + + task_id: str + """Task ID to cancel""" + + +# ============ Queue Publishing Helpers ============ # + + +async def enqueue_copilot_task( + task_id: str, + session_id: str, + user_id: str | None, + operation_id: str, + message: str, + is_user_message: bool = True, + context: dict[str, str] | None = None, +) -> None: + """Enqueue a CoPilot task for processing by the executor service. + + Args: + task_id: Unique identifier for this task (used for stream registry) + session_id: Chat session ID + user_id: User ID (may be None for anonymous users) + operation_id: Operation ID for webhook callbacks and completion tracking + message: User's message to process + is_user_message: Whether the message is from the user (vs system/assistant) + context: Optional context for the message (e.g., {url: str, content: str}) + """ + from backend.util.clients import get_async_copilot_queue + + entry = CoPilotExecutionEntry( + task_id=task_id, + session_id=session_id, + user_id=user_id, + operation_id=operation_id, + message=message, + is_user_message=is_user_message, + context=context, + ) + + queue_client = await get_async_copilot_queue() + await queue_client.publish_message( + routing_key=COPILOT_EXECUTION_ROUTING_KEY, + message=entry.model_dump_json(), + exchange=COPILOT_EXECUTION_EXCHANGE, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/model.py b/autogpt_platform/backend/backend/copilot/model.py similarity index 85% rename from autogpt_platform/backend/backend/api/features/chat/model.py rename to autogpt_platform/backend/backend/copilot/model.py index 30ac27aece..60f05df3cc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/model.py +++ b/autogpt_platform/backend/backend/copilot/model.py @@ -2,7 +2,7 @@ import asyncio import logging import uuid from datetime import UTC, datetime -from typing import Any, cast +from typing import Any, Self, cast from weakref import WeakValueDictionary from openai.types.chat import ( @@ -23,26 +23,17 @@ from prisma.models import ChatMessage as PrismaChatMessage from prisma.models import ChatSession as PrismaChatSession from pydantic import BaseModel +from backend.data.db_accessors import chat_db from backend.data.redis_client import get_redis_async from backend.util import json from backend.util.exceptions import DatabaseError, RedisError -from . import db as chat_db from .config import ChatConfig logger = logging.getLogger(__name__) config = ChatConfig() -def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any: - """Parse a JSON field that may be stored as string or already parsed.""" - if value is None: - return default - if isinstance(value, str): - return json.loads(value) - return value - - # Redis cache key prefix for chat sessions CHAT_SESSION_CACHE_PREFIX = "chat:session:" @@ -52,28 +43,7 @@ def _get_session_cache_key(session_id: str) -> str: return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}" -# Session-level locks to prevent race conditions during concurrent upserts. -# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced, -# preventing unbounded memory growth while maintaining lock semantics for active sessions. -# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after -# async with lock: completes). Explicit cleanup also occurs in delete_chat_session(). -_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary() -_session_locks_mutex = asyncio.Lock() - - -async def _get_session_lock(session_id: str) -> asyncio.Lock: - """Get or create a lock for a specific session to prevent concurrent upserts. - - Uses WeakValueDictionary for automatic cleanup: locks are garbage collected - when no coroutine holds a reference to them, preventing memory leaks from - unbounded growth of session locks. - """ - async with _session_locks_mutex: - lock = _session_locks.get(session_id) - if lock is None: - lock = asyncio.Lock() - _session_locks[session_id] = lock - return lock +# ===================== Chat data models ===================== # class ChatMessage(BaseModel): @@ -85,6 +55,19 @@ class ChatMessage(BaseModel): tool_calls: list[dict] | None = None function_call: dict | None = None + @staticmethod + def from_db(prisma_message: PrismaChatMessage) -> "ChatMessage": + """Convert a Prisma ChatMessage to a Pydantic ChatMessage.""" + return ChatMessage( + role=prisma_message.role, + content=prisma_message.content, + name=prisma_message.name, + tool_call_id=prisma_message.toolCallId, + refusal=prisma_message.refusal, + tool_calls=_parse_json_field(prisma_message.toolCalls), + function_call=_parse_json_field(prisma_message.functionCall), + ) + class Usage(BaseModel): prompt_tokens: int @@ -92,11 +75,10 @@ class Usage(BaseModel): total_tokens: int -class ChatSession(BaseModel): +class ChatSessionInfo(BaseModel): session_id: str user_id: str title: str | None = None - messages: list[ChatMessage] usage: list[Usage] credentials: dict[str, dict] = {} # Map of provider -> credential metadata started_at: datetime @@ -104,60 +86,9 @@ class ChatSession(BaseModel): successful_agent_runs: dict[str, int] = {} successful_agent_schedules: dict[str, int] = {} - def add_tool_call_to_current_turn(self, tool_call: dict) -> None: - """Attach a tool_call to the current turn's assistant message. - - Searches backwards for the most recent assistant message (stopping at - any user message boundary). If found, appends the tool_call to it. - Otherwise creates a new assistant message with the tool_call. - """ - for msg in reversed(self.messages): - if msg.role == "user": - break - if msg.role == "assistant": - if not msg.tool_calls: - msg.tool_calls = [] - msg.tool_calls.append(tool_call) - return - - self.messages.append( - ChatMessage(role="assistant", content="", tool_calls=[tool_call]) - ) - - @staticmethod - def new(user_id: str) -> "ChatSession": - return ChatSession( - session_id=str(uuid.uuid4()), - user_id=user_id, - title=None, - messages=[], - usage=[], - credentials={}, - started_at=datetime.now(UTC), - updated_at=datetime.now(UTC), - ) - - @staticmethod - def from_db( - prisma_session: PrismaChatSession, - prisma_messages: list[PrismaChatMessage] | None = None, - ) -> "ChatSession": - """Convert Prisma models to Pydantic ChatSession.""" - messages = [] - if prisma_messages: - for msg in prisma_messages: - messages.append( - ChatMessage( - role=msg.role, - content=msg.content, - name=msg.name, - tool_call_id=msg.toolCallId, - refusal=msg.refusal, - tool_calls=_parse_json_field(msg.toolCalls), - function_call=_parse_json_field(msg.functionCall), - ) - ) - + @classmethod + def from_db(cls, prisma_session: PrismaChatSession) -> Self: + """Convert Prisma ChatSession to Pydantic ChatSession.""" # Parse JSON fields from Prisma credentials = _parse_json_field(prisma_session.credentials, default={}) successful_agent_runs = _parse_json_field( @@ -179,11 +110,10 @@ class ChatSession(BaseModel): ) ) - return ChatSession( + return cls( session_id=prisma_session.id, user_id=prisma_session.userId, title=prisma_session.title, - messages=messages, usage=usage, credentials=credentials, started_at=prisma_session.createdAt, @@ -192,46 +122,55 @@ class ChatSession(BaseModel): successful_agent_schedules=successful_agent_schedules, ) - @staticmethod - def _merge_consecutive_assistant_messages( - messages: list[ChatCompletionMessageParam], - ) -> list[ChatCompletionMessageParam]: - """Merge consecutive assistant messages into single messages. - Long-running tool flows can create split assistant messages: one with - text content and another with tool_calls. Anthropic's API requires - tool_result blocks to reference a tool_use in the immediately preceding - assistant message, so these splits cause 400 errors via OpenRouter. +class ChatSession(ChatSessionInfo): + messages: list[ChatMessage] + + @classmethod + def new(cls, user_id: str) -> Self: + return cls( + session_id=str(uuid.uuid4()), + user_id=user_id, + title=None, + messages=[], + usage=[], + credentials={}, + started_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + @classmethod + def from_db(cls, prisma_session: PrismaChatSession) -> Self: + """Convert Prisma ChatSession to Pydantic ChatSession.""" + if prisma_session.Messages is None: + raise ValueError( + f"Prisma session {prisma_session.id} is missing Messages relation" + ) + + return cls( + **ChatSessionInfo.from_db(prisma_session).model_dump(), + messages=[ChatMessage.from_db(m) for m in prisma_session.Messages], + ) + + def add_tool_call_to_current_turn(self, tool_call: dict) -> None: + """Attach a tool_call to the current turn's assistant message. + + Searches backwards for the most recent assistant message (stopping at + any user message boundary). If found, appends the tool_call to it. + Otherwise creates a new assistant message with the tool_call. """ - if len(messages) < 2: - return messages + for msg in reversed(self.messages): + if msg.role == "user": + break + if msg.role == "assistant": + if not msg.tool_calls: + msg.tool_calls = [] + msg.tool_calls.append(tool_call) + return - result: list[ChatCompletionMessageParam] = [messages[0]] - for msg in messages[1:]: - prev = result[-1] - if prev.get("role") != "assistant" or msg.get("role") != "assistant": - result.append(msg) - continue - - prev = cast(ChatCompletionAssistantMessageParam, prev) - curr = cast(ChatCompletionAssistantMessageParam, msg) - - curr_content = curr.get("content") or "" - if curr_content: - prev_content = prev.get("content") or "" - prev["content"] = ( - f"{prev_content}\n{curr_content}" if prev_content else curr_content - ) - - curr_tool_calls = curr.get("tool_calls") - if curr_tool_calls: - prev_tool_calls = prev.get("tool_calls") - prev["tool_calls"] = ( - list(prev_tool_calls) + list(curr_tool_calls) - if prev_tool_calls - else list(curr_tool_calls) - ) - return result + self.messages.append( + ChatMessage(role="assistant", content="", tool_calls=[tool_call]) + ) def to_openai_messages(self) -> list[ChatCompletionMessageParam]: messages = [] @@ -321,38 +260,68 @@ class ChatSession(BaseModel): ) return self._merge_consecutive_assistant_messages(messages) + @staticmethod + def _merge_consecutive_assistant_messages( + messages: list[ChatCompletionMessageParam], + ) -> list[ChatCompletionMessageParam]: + """Merge consecutive assistant messages into single messages. -async def _get_session_from_cache(session_id: str) -> ChatSession | None: - """Get a chat session from Redis cache.""" - redis_key = _get_session_cache_key(session_id) - async_redis = await get_redis_async() - raw_session: bytes | None = await async_redis.get(redis_key) + Long-running tool flows can create split assistant messages: one with + text content and another with tool_calls. Anthropic's API requires + tool_result blocks to reference a tool_use in the immediately preceding + assistant message, so these splits cause 400 errors via OpenRouter. + """ + if len(messages) < 2: + return messages - if raw_session is None: - return None + result: list[ChatCompletionMessageParam] = [messages[0]] + for msg in messages[1:]: + prev = result[-1] + if prev.get("role") != "assistant" or msg.get("role") != "assistant": + result.append(msg) + continue - try: - session = ChatSession.model_validate_json(raw_session) - logger.info( - f"[CACHE] Loaded session {session_id}: {len(session.messages)} messages, " - f"last_roles={[m.role for m in session.messages[-3:]]}" # Last 3 roles - ) - return session - except Exception as e: - logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True) - raise RedisError(f"Corrupted session data for {session_id}") from e + prev = cast(ChatCompletionAssistantMessageParam, prev) + curr = cast(ChatCompletionAssistantMessageParam, msg) + + curr_content = curr.get("content") or "" + if curr_content: + prev_content = prev.get("content") or "" + prev["content"] = ( + f"{prev_content}\n{curr_content}" if prev_content else curr_content + ) + + curr_tool_calls = curr.get("tool_calls") + if curr_tool_calls: + prev_tool_calls = prev.get("tool_calls") + prev["tool_calls"] = ( + list(prev_tool_calls) + list(curr_tool_calls) + if prev_tool_calls + else list(curr_tool_calls) + ) + return result -async def _cache_session(session: ChatSession) -> None: - """Cache a chat session in Redis.""" - redis_key = _get_session_cache_key(session.session_id) - async_redis = await get_redis_async() - await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json()) +def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any: + """Parse a JSON field that may be stored as string or already parsed.""" + if value is None: + return default + if isinstance(value, str): + return json.loads(value) + return value + + +# ================ Chat cache + DB operations ================ # + +# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not +# connected directly. async def cache_chat_session(session: ChatSession) -> None: - """Cache a chat session without persisting to the database.""" - await _cache_session(session) + """Cache a chat session in Redis (without persisting to the database).""" + redis_key = _get_session_cache_key(session.session_id) + async_redis = await get_redis_async() + await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json()) async def invalidate_session_cache(session_id: str) -> None: @@ -370,77 +339,6 @@ async def invalidate_session_cache(session_id: str) -> None: logger.warning(f"Failed to invalidate session cache for {session_id}: {e}") -async def _get_session_from_db(session_id: str) -> ChatSession | None: - """Get a chat session from the database.""" - prisma_session = await chat_db.get_chat_session(session_id) - if not prisma_session: - return None - - messages = prisma_session.Messages - logger.debug( - f"[DB] Loaded session {session_id}: {len(messages) if messages else 0} messages, " - f"roles={[m.role for m in messages[-3:]] if messages else []}" # Last 3 roles - ) - - return ChatSession.from_db(prisma_session, messages) - - -async def _save_session_to_db( - session: ChatSession, existing_message_count: int -) -> None: - """Save or update a chat session in the database.""" - # Check if session exists in DB - existing = await chat_db.get_chat_session(session.session_id) - - if not existing: - # Create new session - await chat_db.create_chat_session( - session_id=session.session_id, - user_id=session.user_id, - ) - existing_message_count = 0 - - # Calculate total tokens from usage - total_prompt = sum(u.prompt_tokens for u in session.usage) - total_completion = sum(u.completion_tokens for u in session.usage) - - # Update session metadata - await chat_db.update_chat_session( - session_id=session.session_id, - credentials=session.credentials, - successful_agent_runs=session.successful_agent_runs, - successful_agent_schedules=session.successful_agent_schedules, - total_prompt_tokens=total_prompt, - total_completion_tokens=total_completion, - ) - - # Add new messages (only those after existing count) - new_messages = session.messages[existing_message_count:] - if new_messages: - messages_data = [] - for msg in new_messages: - messages_data.append( - { - "role": msg.role, - "content": msg.content, - "name": msg.name, - "tool_call_id": msg.tool_call_id, - "refusal": msg.refusal, - "tool_calls": msg.tool_calls, - "function_call": msg.function_call, - } - ) - logger.debug( - f"[DB] Saving {len(new_messages)} messages to session {session.session_id}, " - f"roles={[m['role'] for m in messages_data]}" - ) - await chat_db.add_chat_messages_batch( - session_id=session.session_id, - messages=messages_data, - start_sequence=existing_message_count, - ) - - async def get_chat_session( session_id: str, user_id: str | None = None, @@ -488,16 +386,53 @@ async def get_chat_session( # Cache the session from DB try: - await _cache_session(session) + await cache_chat_session(session) + logger.info(f"Cached session {session_id} from database") except Exception as e: logger.warning(f"Failed to cache session {session_id}: {e}") return session -async def upsert_chat_session( - session: ChatSession, -) -> ChatSession: +async def _get_session_from_cache(session_id: str) -> ChatSession | None: + """Get a chat session from Redis cache.""" + redis_key = _get_session_cache_key(session_id) + async_redis = await get_redis_async() + raw_session: bytes | None = await async_redis.get(redis_key) + + if raw_session is None: + return None + + try: + session = ChatSession.model_validate_json(raw_session) + logger.info( + f"Loading session {session_id} from cache: " + f"message_count={len(session.messages)}, " + f"roles={[m.role for m in session.messages]}" + ) + return session + except Exception as e: + logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True) + raise RedisError(f"Corrupted session data for {session_id}") from e + + +async def _get_session_from_db(session_id: str) -> ChatSession | None: + """Get a chat session from the database.""" + session = await chat_db().get_chat_session(session_id) + if not session: + return None + + logger.info( + f"Loaded session {session_id} from DB: " + f"has_messages={bool(session.messages)}, " + f"message_count={len(session.messages)}, " + f"roles={[m.role for m in session.messages]}" + ) + + return session + + +async def upsert_chat_session(session: ChatSession) -> ChatSession: """Update a chat session in both cache and database. Uses session-level locking to prevent race conditions when concurrent @@ -515,7 +450,7 @@ async def upsert_chat_session( async with lock: # Get existing message count from DB for incremental saves - existing_message_count = await chat_db.get_chat_session_message_count( + existing_message_count = await chat_db().get_chat_session_message_count( session.session_id ) @@ -532,7 +467,7 @@ async def upsert_chat_session( # Save to cache (best-effort, even if DB failed) try: - await _cache_session(session) + await cache_chat_session(session) except Exception as e: # If DB succeeded but cache failed, raise cache error if db_error is None: @@ -553,6 +488,65 @@ async def upsert_chat_session( return session +async def _save_session_to_db( + session: ChatSession, existing_message_count: int +) -> None: + """Save or update a chat session in the database.""" + db = chat_db() + + # Check if session exists in DB + existing = await db.get_chat_session(session.session_id) + + if not existing: + # Create new session + await db.create_chat_session( + session_id=session.session_id, + user_id=session.user_id, + ) + existing_message_count = 0 + + # Calculate total tokens from usage + total_prompt = sum(u.prompt_tokens for u in session.usage) + total_completion = sum(u.completion_tokens for u in session.usage) + + # Update session metadata + await db.update_chat_session( + session_id=session.session_id, + credentials=session.credentials, + successful_agent_runs=session.successful_agent_runs, + successful_agent_schedules=session.successful_agent_schedules, + total_prompt_tokens=total_prompt, + total_completion_tokens=total_completion, + ) + + # Add new messages (only those after existing count) + new_messages = session.messages[existing_message_count:] + if new_messages: + messages_data = [] + for msg in new_messages: + messages_data.append( + { + "role": msg.role, + "content": msg.content, + "name": msg.name, + "tool_call_id": msg.tool_call_id, + "refusal": msg.refusal, + "tool_calls": msg.tool_calls, + "function_call": msg.function_call, + } + ) + logger.info( + f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: " + f"roles={[m['role'] for m in messages_data]}, " + f"start_sequence={existing_message_count}" + ) + await db.add_chat_messages_batch( + session_id=session.session_id, + messages=messages_data, + start_sequence=existing_message_count, + ) + + async def append_and_save_message(session_id: str, message: ChatMessage) -> ChatSession: """Atomically append a message to a session and persist it. @@ -568,7 +562,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat raise ValueError(f"Session {session_id} not found") session.messages.append(message) - existing_message_count = await chat_db.get_chat_session_message_count( + existing_message_count = await chat_db().get_chat_session_message_count( session_id ) @@ -580,7 +574,7 @@ async def append_and_save_message(session_id: str, message: ChatMessage) -> Chat ) from e try: - await _cache_session(session) + await cache_chat_session(session) except Exception as e: logger.warning(f"Cache write failed for session {session_id}: {e}") @@ -599,7 +593,7 @@ async def create_chat_session(user_id: str) -> ChatSession: # Create in database first - fail fast if this fails try: - await chat_db.create_chat_session( + await chat_db().create_chat_session( session_id=session.session_id, user_id=user_id, ) @@ -611,7 +605,7 @@ async def create_chat_session(user_id: str) -> ChatSession: # Cache the session (best-effort optimization, DB is source of truth) try: - await _cache_session(session) + await cache_chat_session(session) except Exception as e: logger.warning(f"Failed to cache new session {session.session_id}: {e}") @@ -622,20 +616,16 @@ async def get_user_sessions( user_id: str, limit: int = 50, offset: int = 0, -) -> tuple[list[ChatSession], int]: +) -> tuple[list[ChatSessionInfo], int]: """Get chat sessions for a user from the database with total count. Returns: A tuple of (sessions, total_count) where total_count is the overall number of sessions for the user (not just the current page). """ - prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset) - total_count = await chat_db.get_user_session_count(user_id) - - sessions = [] - for prisma_session in prisma_sessions: - # Convert without messages for listing (lighter weight) - sessions.append(ChatSession.from_db(prisma_session, None)) + db = chat_db() + sessions = await db.get_user_chat_sessions(user_id, limit, offset) + total_count = await db.get_user_session_count(user_id) return sessions, total_count @@ -653,7 +643,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo """ # Delete from database first (with optional user_id validation) # This confirms ownership before invalidating cache - deleted = await chat_db.delete_chat_session(session_id, user_id) + deleted = await chat_db().delete_chat_session(session_id, user_id) if not deleted: return False @@ -688,7 +678,7 @@ async def update_session_title(session_id: str, title: str) -> bool: True if updated successfully, False otherwise. """ try: - result = await chat_db.update_chat_session(session_id=session_id, title=title) + result = await chat_db().update_chat_session(session_id=session_id, title=title) if result is None: logger.warning(f"Session {session_id} not found for title update") return False @@ -700,7 +690,7 @@ async def update_session_title(session_id: str, title: str) -> bool: cached = await _get_session_from_cache(session_id) if cached: cached.title = title - await _cache_session(cached) + await cache_chat_session(cached) except Exception as e: # Not critical - title will be correct on next full cache refresh logger.warning( @@ -711,3 +701,29 @@ async def update_session_title(session_id: str, title: str) -> bool: except Exception as e: logger.error(f"Failed to update title for session {session_id}: {e}") return False + + +# ==================== Chat session locks ==================== # + +_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary() +_session_locks_mutex = asyncio.Lock() + + +async def _get_session_lock(session_id: str) -> asyncio.Lock: + """Get or create a lock for a specific session to prevent concurrent upserts. + + This was originally added to solve the specific problem of race conditions between + the session title thread and the conversation thread, which always occurs on the + same instance as we prevent rapid request sends on the frontend. + + Uses WeakValueDictionary for automatic cleanup: locks are garbage collected + when no coroutine holds a reference to them, preventing memory leaks from + unbounded growth of session locks. Explicit cleanup also occurs + in `delete_chat_session()`. + """ + async with _session_locks_mutex: + lock = _session_locks.get(session_id) + if lock is None: + lock = asyncio.Lock() + _session_locks[session_id] = lock + return lock diff --git a/autogpt_platform/backend/backend/api/features/chat/model_test.py b/autogpt_platform/backend/backend/copilot/model_test.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/model_test.py rename to autogpt_platform/backend/backend/copilot/model_test.py diff --git a/autogpt_platform/backend/backend/api/features/chat/response_model.py b/autogpt_platform/backend/backend/copilot/response_model.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/response_model.py rename to autogpt_platform/backend/backend/copilot/response_model.py diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py b/autogpt_platform/backend/backend/copilot/sdk/__init__.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/sdk/__init__.py rename to autogpt_platform/backend/backend/copilot/sdk/__init__.py diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py b/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py similarity index 97% rename from autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py rename to autogpt_platform/backend/backend/copilot/sdk/response_adapter.py index f7151f8319..7a3976ae42 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter.py +++ b/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py @@ -20,7 +20,7 @@ from claude_agent_sdk import ( UserMessage, ) -from backend.api.features.chat.response_model import ( +from backend.copilot.response_model import ( StreamBaseResponse, StreamError, StreamFinish, @@ -34,10 +34,8 @@ from backend.api.features.chat.response_model import ( StreamToolInputStart, StreamToolOutputAvailable, ) -from backend.api.features.chat.sdk.tool_adapter import ( - MCP_TOOL_PREFIX, - pop_pending_tool_output, -) + +from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py b/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py similarity index 99% rename from autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py rename to autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py index a4f2502642..7555eb8046 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/response_adapter_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py @@ -10,7 +10,7 @@ from claude_agent_sdk import ( UserMessage, ) -from backend.api.features.chat.response_model import ( +from backend.copilot.response_model import ( StreamBaseResponse, StreamError, StreamFinish, diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py similarity index 99% rename from autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py rename to autogpt_platform/backend/backend/copilot/sdk/security_hooks.py index 89853402a3..7bae54e38d 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks.py +++ b/autogpt_platform/backend/backend/copilot/sdk/security_hooks.py @@ -11,7 +11,7 @@ import re from collections.abc import Callable from typing import Any, cast -from backend.api.features.chat.sdk.tool_adapter import ( +from .tool_adapter import ( BLOCKED_TOOLS, DANGEROUS_PATTERNS, MCP_TOOL_PREFIX, diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py b/autogpt_platform/backend/backend/copilot/sdk/security_hooks_test.py similarity index 83% rename from autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py rename to autogpt_platform/backend/backend/copilot/sdk/security_hooks_test.py index 2d09afdab7..e1891cf1bd 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/security_hooks_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/security_hooks_test.py @@ -1,4 +1,9 @@ -"""Unit tests for SDK security hooks.""" +"""Tests for SDK security hooks — workspace paths, tool access, and deny messages. + +These are pure unit tests with no external dependencies (no SDK, no DB, no server). +They validate that the security hooks correctly block unauthorized paths, +tool access, and dangerous input patterns. +""" import os @@ -12,6 +17,10 @@ def _is_denied(result: dict) -> bool: return hook.get("permissionDecision") == "deny" +def _reason(result: dict) -> str: + return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "") + + # -- Blocked tools ----------------------------------------------------------- @@ -163,3 +172,19 @@ def test_non_workspace_tool_passes_isolation(): "find_agent", {"query": "email"}, user_id="user-1" ) assert result == {} + + +# -- Deny message quality ---------------------------------------------------- + + +def test_blocked_tool_message_clarity(): + """Deny messages must include [SECURITY] and 'cannot be bypassed'.""" + reason = _reason(_validate_tool_access("bash", {})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason + + +def test_bash_builtin_blocked_message_clarity(): + reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"})) + assert "[SECURITY]" in reason + assert "cannot be bypassed" in reason diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py similarity index 98% rename from autogpt_platform/backend/backend/api/features/chat/sdk/service.py rename to autogpt_platform/backend/backend/copilot/sdk/service.py index 65c4cebb06..076e7b5743 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -255,7 +255,7 @@ def _build_sdk_env() -> dict[str, str]: def _make_sdk_cwd(session_id: str) -> str: """Create a safe, session-specific working directory path. - Delegates to :func:`~backend.api.features.chat.tools.sandbox.make_session_path` + Delegates to :func:`~backend.copilot.tools.sandbox.make_session_path` (single source of truth for path sanitization) and adds a defence-in-depth assertion. """ @@ -440,12 +440,16 @@ async def stream_chat_completion_sdk( f"Session {session_id} not found. Please create a new session first." ) - if message: - session.messages.append( - ChatMessage( - role="user" if is_user_message else "assistant", content=message - ) + # Append the new message to the session if it's not already there + new_message_role = "user" if is_user_message else "assistant" + if message and ( + len(session.messages) == 0 + or not ( + session.messages[-1].role == new_message_role + and session.messages[-1].content == message ) + ): + session.messages.append(ChatMessage(role=new_message_role, content=message)) if is_user_message: track_user_message( user_id=user_id, session_id=session_id, message_length=len(message) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py b/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py similarity index 98% rename from autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py rename to autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py index 2d259730bf..68364e7797 100644 --- a/autogpt_platform/backend/backend/api/features/chat/sdk/tool_adapter.py +++ b/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py @@ -18,9 +18,9 @@ from collections.abc import Awaitable, Callable from contextvars import ContextVar from typing import Any -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools import TOOL_REGISTRY -from backend.api.features.chat.tools.base import BaseTool +from backend.copilot.model import ChatSession +from backend.copilot.tools import TOOL_REGISTRY +from backend.copilot.tools.base import BaseTool logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py b/autogpt_platform/backend/backend/copilot/sdk/transcript.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/sdk/transcript.py rename to autogpt_platform/backend/backend/copilot/sdk/transcript.py diff --git a/autogpt_platform/backend/test/chat/test_transcript.py b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py similarity index 99% rename from autogpt_platform/backend/test/chat/test_transcript.py rename to autogpt_platform/backend/backend/copilot/sdk/transcript_test.py index 71b1fad81f..b4b65fd526 100644 --- a/autogpt_platform/backend/test/chat/test_transcript.py +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py @@ -3,7 +3,7 @@ import json import os -from backend.api.features.chat.sdk.transcript import ( +from .transcript import ( STRIPPABLE_TYPES, read_transcript_file, strip_progress_entries, diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/copilot/service.py similarity index 99% rename from autogpt_platform/backend/backend/api/features/chat/service.py rename to autogpt_platform/backend/backend/copilot/service.py index cb5591e6d0..4dc5f05c25 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/copilot/service.py @@ -27,20 +27,18 @@ from openai.types.chat import ( ChatCompletionToolParam, ) +from backend.data.db_accessors import chat_db, understanding_db from backend.data.redis_client import get_redis_async -from backend.data.understanding import ( - format_understanding_for_prompt, - get_business_understanding, -) +from backend.data.understanding import format_understanding_for_prompt from backend.util.exceptions import NotFoundError from backend.util.settings import AppEnvironment, Settings -from . import db as chat_db from . import stream_registry from .config import ChatConfig from .model import ( ChatMessage, ChatSession, + ChatSessionInfo, Usage, cache_chat_session, get_chat_session, @@ -263,7 +261,7 @@ async def _build_system_prompt( understanding = None if user_id: try: - understanding = await get_business_understanding(user_id) + understanding = await understanding_db().get_business_understanding(user_id) except Exception as e: logger.warning(f"Failed to fetch business understanding: {e}") understanding = None @@ -339,7 +337,7 @@ async def _generate_session_title( async def assign_user_to_session( session_id: str, user_id: str, -) -> ChatSession: +) -> ChatSessionInfo: """ Assign a user to a chat session. """ @@ -428,12 +426,16 @@ async def stream_chat_completion( f"Session {session_id} not found. Please create a new session first." ) - if message: - session.messages.append( - ChatMessage( - role="user" if is_user_message else "assistant", content=message - ) + # Append the new message to the session if it's not already there + new_message_role = "user" if is_user_message else "assistant" + if message and ( + len(session.messages) == 0 + or not ( + session.messages[-1].role == new_message_role + and session.messages[-1].content == message ) + ): + session.messages.append(ChatMessage(role=new_message_role, content=message)) logger.info( f"Appended message (role={'user' if is_user_message else 'assistant'}), " f"new message_count={len(session.messages)}" @@ -1770,7 +1772,7 @@ async def _update_pending_operation( This is called by background tasks when long-running operations complete. """ # Update the message in database - updated = await chat_db.update_tool_message_content( + updated = await chat_db().update_tool_message_content( session_id=session_id, tool_call_id=tool_call_id, new_content=result, diff --git a/autogpt_platform/backend/backend/api/features/chat/service_test.py b/autogpt_platform/backend/backend/copilot/service_test.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/service_test.py rename to autogpt_platform/backend/backend/copilot/service_test.py diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/stream_registry.py rename to autogpt_platform/backend/backend/copilot/stream_registry.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/IDEAS.md b/autogpt_platform/backend/backend/copilot/tools/IDEAS.md similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/IDEAS.md rename to autogpt_platform/backend/backend/copilot/tools/IDEAS.md diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/copilot/tools/__init__.py similarity index 94% rename from autogpt_platform/backend/backend/api/features/chat/tools/__init__.py rename to autogpt_platform/backend/backend/copilot/tools/__init__.py index 1ab4f720bb..0593fe69c0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/copilot/tools/__init__.py @@ -3,8 +3,8 @@ from typing import TYPE_CHECKING, Any from openai.types.chat import ChatCompletionToolParam -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tracking import track_tool_called +from backend.copilot.model import ChatSession +from backend.copilot.tracking import track_tool_called from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool @@ -31,7 +31,7 @@ from .workspace_files import ( ) if TYPE_CHECKING: - from backend.api.features.chat.response_model import StreamToolOutputAvailable + from backend.copilot.response_model import StreamToolOutputAvailable logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/_test_data.py b/autogpt_platform/backend/backend/copilot/tools/_test_data.py similarity index 99% rename from autogpt_platform/backend/backend/api/features/chat/tools/_test_data.py rename to autogpt_platform/backend/backend/copilot/tools/_test_data.py index a8f208ebb0..c1d45d00df 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/_test_data.py +++ b/autogpt_platform/backend/backend/copilot/tools/_test_data.py @@ -6,11 +6,11 @@ import pytest from prisma.types import ProfileCreateInput from pydantic import SecretStr -from backend.api.features.chat.model import ChatSession from backend.api.features.store import db as store_db from backend.blocks.firecrawl.scrape import FirecrawlScrapeBlock from backend.blocks.io import AgentInputBlock, AgentOutputBlock from backend.blocks.llm import AITextGeneratorBlock +from backend.copilot.model import ChatSession from backend.data.db import prisma from backend.data.graph import Graph, Link, Node, create_graph from backend.data.model import APIKeyCredentials diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/add_understanding.py b/autogpt_platform/backend/backend/copilot/tools/add_understanding.py similarity index 93% rename from autogpt_platform/backend/backend/api/features/chat/tools/add_understanding.py rename to autogpt_platform/backend/backend/copilot/tools/add_understanding.py index fe3d5e8984..b3291c5b0e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/add_understanding.py +++ b/autogpt_platform/backend/backend/copilot/tools/add_understanding.py @@ -3,11 +3,9 @@ import logging from typing import Any -from backend.api.features.chat.model import ChatSession -from backend.data.understanding import ( - BusinessUnderstandingInput, - upsert_business_understanding, -) +from backend.copilot.model import ChatSession +from backend.data.db_accessors import understanding_db +from backend.data.understanding import BusinessUnderstandingInput from .base import BaseTool from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse @@ -99,7 +97,9 @@ and automations for the user's specific needs.""" ] # Upsert with merge - understanding = await upsert_business_understanding(user_id, input_data) + understanding = await understanding_db().upsert_business_understanding( + user_id, input_data + ) # Build current understanding summary (filter out empty values) current_understanding = { diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py b/autogpt_platform/backend/backend/copilot/tools/agent_generator/__init__.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py rename to autogpt_platform/backend/backend/copilot/tools/agent_generator/__init__.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/copilot/tools/agent_generator/core.py similarity index 96% rename from autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py rename to autogpt_platform/backend/backend/copilot/tools/agent_generator/core.py index f83ca30b5c..8fbe267f74 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/copilot/tools/agent_generator/core.py @@ -5,9 +5,8 @@ import re import uuid from typing import Any, NotRequired, TypedDict -from backend.api.features.library import db as library_db -from backend.api.features.store import db as store_db -from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs +from backend.data.db_accessors import graph_db, library_db, store_db +from backend.data.graph import Graph, Link, Node from backend.util.exceptions import DatabaseError, NotFoundError from .service import ( @@ -145,8 +144,9 @@ async def get_library_agent_by_id( Returns: LibraryAgentSummary if found, None otherwise """ + db = library_db() try: - agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id) + agent = await db.get_library_agent_by_graph_id(user_id, agent_id) if agent: logger.debug(f"Found library agent by graph_id: {agent.name}") return LibraryAgentSummary( @@ -163,7 +163,7 @@ async def get_library_agent_by_id( logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}") try: - agent = await library_db.get_library_agent(agent_id, user_id) + agent = await db.get_library_agent(agent_id, user_id) if agent: logger.debug(f"Found library agent by library_id: {agent.name}") return LibraryAgentSummary( @@ -215,7 +215,7 @@ async def get_library_agents_for_generation( List of LibraryAgentSummary with schemas and recent executions for sub-agent composition """ try: - response = await library_db.list_library_agents( + response = await library_db().list_library_agents( user_id=user_id, search_term=search_query, page=1, @@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation( List of LibraryAgentSummary with full input/output schemas """ try: - response = await store_db.get_store_agents( + response = await store_db().get_store_agents( search_query=search_query, page=1, page_size=max_results, @@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation( return [] graph_ids = [agent.agent_graph_id for agent in agents_with_graphs] - graphs = await get_store_listed_graphs(*graph_ids) + graphs = await graph_db().get_store_listed_graphs(graph_ids) results: list[LibraryAgentSummary] = [] for agent in agents_with_graphs: @@ -673,9 +673,10 @@ async def save_agent_to_library( Tuple of (created Graph, LibraryAgent) """ graph = json_to_graph(agent_json) + db = library_db() if is_update: - return await library_db.update_graph_in_library(graph, user_id) - return await library_db.create_graph_in_library(graph, user_id) + return await db.update_graph_in_library(graph, user_id) + return await db.create_graph_in_library(graph, user_id) def graph_to_json(graph: Graph) -> dict[str, Any]: @@ -735,12 +736,14 @@ async def get_agent_as_json( Returns: Agent as JSON dict or None if not found """ - graph = await get_graph(agent_id, version=None, user_id=user_id) + db = graph_db() + + graph = await db.get_graph(agent_id, version=None, user_id=user_id) if not graph and user_id: try: - library_agent = await library_db.get_library_agent(agent_id, user_id) - graph = await get_graph( + library_agent = await library_db().get_library_agent(agent_id, user_id) + graph = await db.get_graph( library_agent.graph_id, version=None, user_id=user_id ) except NotFoundError: diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/dummy.py b/autogpt_platform/backend/backend/copilot/tools/agent_generator/dummy.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/dummy.py rename to autogpt_platform/backend/backend/copilot/tools/agent_generator/dummy.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py b/autogpt_platform/backend/backend/copilot/tools/agent_generator/errors.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py rename to autogpt_platform/backend/backend/copilot/tools/agent_generator/errors.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/copilot/tools/agent_generator/service.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py rename to autogpt_platform/backend/backend/copilot/tools/agent_generator/service.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py b/autogpt_platform/backend/backend/copilot/tools/agent_output.py similarity index 94% rename from autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py rename to autogpt_platform/backend/backend/copilot/tools/agent_output.py index 457e4a4f9b..fe4767d09e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_output.py +++ b/autogpt_platform/backend/backend/copilot/tools/agent_output.py @@ -7,10 +7,9 @@ from typing import Any from pydantic import BaseModel, field_validator -from backend.api.features.chat.model import ChatSession -from backend.api.features.library import db as library_db from backend.api.features.library.model import LibraryAgent -from backend.data import execution as execution_db +from backend.copilot.model import ChatSession +from backend.data.db_accessors import execution_db, library_db from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta from .base import BaseTool @@ -165,10 +164,12 @@ class AgentOutputTool(BaseTool): Resolve agent from provided identifiers. Returns (library_agent, error_message). """ + lib_db = library_db() + # Priority 1: Exact library agent ID if library_agent_id: try: - agent = await library_db.get_library_agent(library_agent_id, user_id) + agent = await lib_db.get_library_agent(library_agent_id, user_id) return agent, None except Exception as e: logger.warning(f"Failed to get library agent by ID: {e}") @@ -182,7 +183,7 @@ class AgentOutputTool(BaseTool): return None, f"Agent '{store_slug}' not found in marketplace" # Find in user's library by graph_id - agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id) + agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id) if not agent: return ( None, @@ -194,7 +195,7 @@ class AgentOutputTool(BaseTool): # Priority 3: Fuzzy name search in library if agent_name: try: - response = await library_db.list_library_agents( + response = await lib_db.list_library_agents( user_id=user_id, search_term=agent_name, page_size=5, @@ -228,9 +229,11 @@ class AgentOutputTool(BaseTool): Fetch execution(s) based on filters. Returns (single_execution, available_executions_meta, error_message). """ + exec_db = execution_db() + # If specific execution_id provided, fetch it directly if execution_id: - execution = await execution_db.get_graph_execution( + execution = await exec_db.get_graph_execution( user_id=user_id, execution_id=execution_id, include_node_executions=False, @@ -240,7 +243,7 @@ class AgentOutputTool(BaseTool): return execution, [], None # Get completed executions with time filters - executions = await execution_db.get_graph_executions( + executions = await exec_db.get_graph_executions( graph_id=graph_id, user_id=user_id, statuses=[ExecutionStatus.COMPLETED], @@ -254,7 +257,7 @@ class AgentOutputTool(BaseTool): # If only one execution, fetch full details if len(executions) == 1: - full_execution = await execution_db.get_graph_execution( + full_execution = await exec_db.get_graph_execution( user_id=user_id, execution_id=executions[0].id, include_node_executions=False, @@ -262,7 +265,7 @@ class AgentOutputTool(BaseTool): return full_execution, [], None # Multiple executions - return latest with full details, plus list of available - full_execution = await execution_db.get_graph_execution( + full_execution = await exec_db.get_graph_execution( user_id=user_id, execution_id=executions[0].id, include_node_executions=False, @@ -380,7 +383,7 @@ class AgentOutputTool(BaseTool): and not input_data.store_slug ): # Fetch execution directly to get graph_id - execution = await execution_db.get_graph_execution( + execution = await execution_db().get_graph_execution( user_id=user_id, execution_id=input_data.execution_id, include_node_executions=False, @@ -392,7 +395,7 @@ class AgentOutputTool(BaseTool): ) # Find library agent by graph_id - agent = await library_db.get_library_agent_by_graph_id( + agent = await library_db().get_library_agent_by_graph_id( user_id, execution.graph_id ) if not agent: diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py b/autogpt_platform/backend/backend/copilot/tools/agent_search.py similarity index 95% rename from autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py rename to autogpt_platform/backend/backend/copilot/tools/agent_search.py index 61cdba1ef9..3c380a7150 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py +++ b/autogpt_platform/backend/backend/copilot/tools/agent_search.py @@ -4,8 +4,7 @@ import logging import re from typing import Literal -from backend.api.features.library import db as library_db -from backend.api.features.store import db as store_db +from backend.data.db_accessors import library_db, store_db from backend.util.exceptions import DatabaseError, NotFoundError from .models import ( @@ -45,8 +44,10 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N Returns: AgentInfo if found, None otherwise """ + lib_db = library_db() + try: - agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id) + agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id) if agent: logger.debug(f"Found library agent by graph_id: {agent.name}") return AgentInfo( @@ -71,7 +72,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N ) try: - agent = await library_db.get_library_agent(agent_id, user_id) + agent = await lib_db.get_library_agent(agent_id, user_id) if agent: logger.debug(f"Found library agent by library_id: {agent.name}") return AgentInfo( @@ -133,7 +134,7 @@ async def search_agents( try: if source == "marketplace": logger.info(f"Searching marketplace for: {query}") - results = await store_db.get_store_agents(search_query=query, page_size=5) + results = await store_db().get_store_agents(search_query=query, page_size=5) for agent in results.agents: agents.append( AgentInfo( @@ -159,7 +160,7 @@ async def search_agents( if not agents: logger.info(f"Searching user library for: {query}") - results = await library_db.list_library_agents( + results = await library_db().list_library_agents( user_id=user_id, # type: ignore[arg-type] search_term=query, page_size=10, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/base.py b/autogpt_platform/backend/backend/copilot/tools/base.py similarity index 96% rename from autogpt_platform/backend/backend/api/features/chat/tools/base.py rename to autogpt_platform/backend/backend/copilot/tools/base.py index 809e06632b..e821b1844f 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/base.py +++ b/autogpt_platform/backend/backend/copilot/tools/base.py @@ -5,8 +5,8 @@ from typing import Any from openai.types.chat import ChatCompletionToolParam -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.response_model import StreamToolOutputAvailable +from backend.copilot.model import ChatSession +from backend.copilot.response_model import StreamToolOutputAvailable from .models import ErrorResponse, NeedLoginResponse, ToolResponseBase diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py b/autogpt_platform/backend/backend/copilot/tools/bash_exec.py similarity index 92% rename from autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py rename to autogpt_platform/backend/backend/copilot/tools/bash_exec.py index da9d8bf3fa..6e32a3c720 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/bash_exec.py +++ b/autogpt_platform/backend/backend/copilot/tools/bash_exec.py @@ -11,18 +11,11 @@ available (e.g. macOS development). import logging from typing import Any -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools.base import BaseTool -from backend.api.features.chat.tools.models import ( - BashExecResponse, - ErrorResponse, - ToolResponseBase, -) -from backend.api.features.chat.tools.sandbox import ( - get_workspace_dir, - has_full_sandbox, - run_sandboxed, -) +from backend.copilot.model import ChatSession + +from .base import BaseTool +from .models import BashExecResponse, ErrorResponse, ToolResponseBase +from .sandbox import get_workspace_dir, has_full_sandbox, run_sandboxed logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py b/autogpt_platform/backend/backend/copilot/tools/check_operation_status.py similarity index 93% rename from autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py rename to autogpt_platform/backend/backend/copilot/tools/check_operation_status.py index b8ec770fd0..a03fe074ba 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/check_operation_status.py +++ b/autogpt_platform/backend/backend/copilot/tools/check_operation_status.py @@ -3,13 +3,10 @@ import logging from typing import Any -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools.base import BaseTool -from backend.api.features.chat.tools.models import ( - ErrorResponse, - ResponseType, - ToolResponseBase, -) +from backend.copilot.model import ChatSession + +from .base import BaseTool +from .models import ErrorResponse, ResponseType, ToolResponseBase logger = logging.getLogger(__name__) @@ -78,7 +75,7 @@ class CheckOperationStatusTool(BaseTool): session: ChatSession, **kwargs, ) -> ToolResponseBase: - from backend.api.features.chat import stream_registry + from backend.copilot import stream_registry operation_id = (kwargs.get("operation_id") or "").strip() task_id = (kwargs.get("task_id") or "").strip() diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/copilot/tools/create_agent.py similarity index 99% rename from autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py rename to autogpt_platform/backend/backend/copilot/tools/create_agent.py index 7333851a5b..74d6227b09 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py +++ b/autogpt_platform/backend/backend/copilot/tools/create_agent.py @@ -3,7 +3,7 @@ import logging from typing import Any -from backend.api.features.chat.model import ChatSession +from backend.copilot.model import ChatSession from .agent_generator import ( AgentGeneratorNotConfiguredError, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py b/autogpt_platform/backend/backend/copilot/tools/customize_agent.py similarity index 98% rename from autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py rename to autogpt_platform/backend/backend/copilot/tools/customize_agent.py index c0568bd936..96e19656c6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py +++ b/autogpt_platform/backend/backend/copilot/tools/customize_agent.py @@ -3,9 +3,9 @@ import logging from typing import Any -from backend.api.features.chat.model import ChatSession -from backend.api.features.store import db as store_db from backend.api.features.store.exceptions import AgentNotFoundError +from backend.copilot.model import ChatSession +from backend.data.db_accessors import store_db as get_store_db from .agent_generator import ( AgentGeneratorNotConfiguredError, @@ -137,6 +137,8 @@ class CustomizeAgentTool(BaseTool): creator_username, agent_slug = parts + store_db = get_store_db() + # Fetch the marketplace agent details try: agent_details = await store_db.get_store_agent_details( diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/copilot/tools/edit_agent.py similarity index 99% rename from autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py rename to autogpt_platform/backend/backend/copilot/tools/edit_agent.py index 3ae56407a7..14d3a8d8f9 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py +++ b/autogpt_platform/backend/backend/copilot/tools/edit_agent.py @@ -3,7 +3,7 @@ import logging from typing import Any -from backend.api.features.chat.model import ChatSession +from backend.copilot.model import ChatSession from .agent_generator import ( AgentGeneratorNotConfiguredError, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py b/autogpt_platform/backend/backend/copilot/tools/feature_requests.py similarity index 97% rename from autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py rename to autogpt_platform/backend/backend/copilot/tools/feature_requests.py index 8346df1177..2c9e8cd017 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests.py +++ b/autogpt_platform/backend/backend/copilot/tools/feature_requests.py @@ -5,9 +5,14 @@ from typing import Any from pydantic import SecretStr -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools.base import BaseTool -from backend.api.features.chat.tools.models import ( +from backend.blocks.linear._api import LinearClient +from backend.copilot.model import ChatSession +from backend.data.db_accessors import user_db +from backend.data.model import APIKeyCredentials +from backend.util.settings import Settings + +from .base import BaseTool +from .models import ( ErrorResponse, FeatureRequestCreatedResponse, FeatureRequestInfo, @@ -15,10 +20,6 @@ from backend.api.features.chat.tools.models import ( NoResultsResponse, ToolResponseBase, ) -from backend.blocks.linear._api import LinearClient -from backend.data.model import APIKeyCredentials -from backend.data.user import get_user_email_by_id -from backend.util.settings import Settings logger = logging.getLogger(__name__) @@ -332,7 +333,9 @@ class CreateFeatureRequestTool(BaseTool): # Resolve a human-readable name (email) for the Linear customer record. # Fall back to user_id if the lookup fails or returns None. try: - customer_display_name = await get_user_email_by_id(user_id) or user_id + customer_display_name = ( + await user_db().get_user_email_by_id(user_id) or user_id + ) except Exception: customer_display_name = user_id diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py b/autogpt_platform/backend/backend/copilot/tools/feature_requests_test.py similarity index 97% rename from autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py rename to autogpt_platform/backend/backend/copilot/tools/feature_requests_test.py index 438725368f..a24eb4de22 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/feature_requests_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/feature_requests_test.py @@ -1,22 +1,18 @@ """Tests for SearchFeatureRequestsTool and CreateFeatureRequestTool.""" -from unittest.mock import AsyncMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest -from backend.api.features.chat.tools.feature_requests import ( - CreateFeatureRequestTool, - SearchFeatureRequestsTool, -) -from backend.api.features.chat.tools.models import ( +from ._test_data import make_session +from .feature_requests import CreateFeatureRequestTool, SearchFeatureRequestsTool +from .models import ( ErrorResponse, FeatureRequestCreatedResponse, FeatureRequestSearchResponse, NoResultsResponse, ) -from ._test_data import make_session - _TEST_USER_ID = "test-user-feature-requests" _TEST_USER_EMAIL = "testuser@example.com" @@ -39,7 +35,7 @@ def _mock_linear_config(*, query_return=None, mutate_return=None): client.mutate.return_value = mutate_return return ( patch( - "backend.api.features.chat.tools.feature_requests._get_linear_config", + "backend.copilot.tools.feature_requests._get_linear_config", return_value=(client, _FAKE_PROJECT_ID, _FAKE_TEAM_ID), ), client, @@ -208,7 +204,7 @@ class TestSearchFeatureRequestsTool: async def test_linear_client_init_failure(self): session = make_session(user_id=_TEST_USER_ID) with patch( - "backend.api.features.chat.tools.feature_requests._get_linear_config", + "backend.copilot.tools.feature_requests._get_linear_config", side_effect=RuntimeError("No API key"), ): tool = SearchFeatureRequestsTool() @@ -231,10 +227,11 @@ class TestCreateFeatureRequestTool: @pytest.fixture(autouse=True) def _patch_email_lookup(self): + mock_user_db = MagicMock() + mock_user_db.get_user_email_by_id = AsyncMock(return_value=_TEST_USER_EMAIL) with patch( - "backend.api.features.chat.tools.feature_requests.get_user_email_by_id", - new_callable=AsyncMock, - return_value=_TEST_USER_EMAIL, + "backend.copilot.tools.feature_requests.user_db", + return_value=mock_user_db, ): yield @@ -347,7 +344,7 @@ class TestCreateFeatureRequestTool: async def test_linear_client_init_failure(self): session = make_session(user_id=_TEST_USER_ID) with patch( - "backend.api.features.chat.tools.feature_requests._get_linear_config", + "backend.copilot.tools.feature_requests._get_linear_config", side_effect=RuntimeError("No API key"), ): tool = CreateFeatureRequestTool() diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py b/autogpt_platform/backend/backend/copilot/tools/find_agent.py similarity index 95% rename from autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py rename to autogpt_platform/backend/backend/copilot/tools/find_agent.py index 477522757d..32e5bce454 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_agent.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_agent.py @@ -2,7 +2,7 @@ from typing import Any -from backend.api.features.chat.model import ChatSession +from backend.copilot.model import ChatSession from .agent_search import search_agents from .base import BaseTool diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py b/autogpt_platform/backend/backend/copilot/tools/find_block.py similarity index 95% rename from autogpt_platform/backend/backend/api/features/chat/tools/find_block.py rename to autogpt_platform/backend/backend/copilot/tools/find_block.py index c51317cb62..a3f784f3a8 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block.py @@ -3,17 +3,18 @@ from typing import Any from prisma.enums import ContentType -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools.base import BaseTool, ToolResponseBase -from backend.api.features.chat.tools.models import ( +from backend.blocks import get_block +from backend.blocks._base import BlockType +from backend.copilot.model import ChatSession +from backend.data.db_accessors import search + +from .base import BaseTool, ToolResponseBase +from .models import ( BlockInfoSummary, BlockListResponse, ErrorResponse, NoResultsResponse, ) -from backend.api.features.store.hybrid_search import unified_hybrid_search -from backend.blocks import get_block -from backend.blocks._base import BlockType logger = logging.getLogger(__name__) @@ -107,7 +108,7 @@ class FindBlockTool(BaseTool): try: # Search for blocks using hybrid search - results, total = await unified_hybrid_search( + results, total = await search().unified_hybrid_search( query=query, content_types=[ContentType.BLOCK], page=1, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py b/autogpt_platform/backend/backend/copilot/tools/find_block_test.py similarity index 93% rename from autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py rename to autogpt_platform/backend/backend/copilot/tools/find_block_test.py index 44606f81c3..ebd3c761ab 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_block_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block_test.py @@ -4,15 +4,15 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from backend.api.features.chat.tools.find_block import ( +from backend.blocks._base import BlockType + +from ._test_data import make_session +from .find_block import ( COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES, FindBlockTool, ) -from backend.api.features.chat.tools.models import BlockListResponse -from backend.blocks._base import BlockType - -from ._test_data import make_session +from .models import BlockListResponse _TEST_USER_ID = "test-user-find-block" @@ -84,13 +84,17 @@ class TestFindBlockFiltering: "standard-block-id": standard_block, }.get(block_id) + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + with patch( - "backend.api.features.chat.tools.find_block.unified_hybrid_search", - new_callable=AsyncMock, - return_value=(search_results, 2), + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, ): with patch( - "backend.api.features.chat.tools.find_block.get_block", + "backend.copilot.tools.find_block.get_block", side_effect=mock_get_block, ): tool = FindBlockTool() @@ -128,13 +132,17 @@ class TestFindBlockFiltering: "normal-block-id": normal_block, }.get(block_id) + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, 2) + ) + with patch( - "backend.api.features.chat.tools.find_block.unified_hybrid_search", - new_callable=AsyncMock, - return_value=(search_results, 2), + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, ): with patch( - "backend.api.features.chat.tools.find_block.get_block", + "backend.copilot.tools.find_block.get_block", side_effect=mock_get_block, ): tool = FindBlockTool() @@ -353,12 +361,16 @@ class TestFindBlockFiltering: for d in block_defs } + mock_search_db = MagicMock() + mock_search_db.unified_hybrid_search = AsyncMock( + return_value=(search_results, len(search_results)) + ) + with patch( - "backend.api.features.chat.tools.find_block.unified_hybrid_search", - new_callable=AsyncMock, - return_value=(search_results, len(search_results)), + "backend.copilot.tools.find_block.search", + return_value=mock_search_db, ), patch( - "backend.api.features.chat.tools.find_block.get_block", + "backend.copilot.tools.find_block.get_block", side_effect=lambda bid: mock_blocks.get(bid), ): tool = FindBlockTool() diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py b/autogpt_platform/backend/backend/copilot/tools/find_library_agent.py similarity index 96% rename from autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py rename to autogpt_platform/backend/backend/copilot/tools/find_library_agent.py index 108fba75ae..16ae90e40b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/find_library_agent.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_library_agent.py @@ -2,7 +2,7 @@ from typing import Any -from backend.api.features.chat.model import ChatSession +from backend.copilot.model import ChatSession from .agent_search import search_agents from .base import BaseTool diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py b/autogpt_platform/backend/backend/copilot/tools/get_doc_page.py similarity index 95% rename from autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py rename to autogpt_platform/backend/backend/copilot/tools/get_doc_page.py index 7040cd7db5..87ec7225a5 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/get_doc_page.py +++ b/autogpt_platform/backend/backend/copilot/tools/get_doc_page.py @@ -4,13 +4,10 @@ import logging from pathlib import Path from typing import Any -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools.base import BaseTool -from backend.api.features.chat.tools.models import ( - DocPageResponse, - ErrorResponse, - ToolResponseBase, -) +from backend.copilot.model import ChatSession + +from .base import BaseTool +from .models import DocPageResponse, ErrorResponse, ToolResponseBase logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/helpers.py b/autogpt_platform/backend/backend/copilot/tools/helpers.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/helpers.py rename to autogpt_platform/backend/backend/copilot/tools/helpers.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/copilot/tools/models.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/models.py rename to autogpt_platform/backend/backend/copilot/tools/models.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/copilot/tools/run_agent.py similarity index 97% rename from autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py rename to autogpt_platform/backend/backend/copilot/tools/run_agent.py index a9f19bcf62..46af6fbcb0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/copilot/tools/run_agent.py @@ -5,16 +5,12 @@ from typing import Any from pydantic import BaseModel, Field, field_validator -from backend.api.features.chat.config import ChatConfig -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tracking import ( - track_agent_run_success, - track_agent_scheduled, -) -from backend.api.features.library import db as library_db +from backend.copilot.config import ChatConfig +from backend.copilot.model import ChatSession +from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled +from backend.data.db_accessors import graph_db, library_db, user_db from backend.data.graph import GraphModel from backend.data.model import CredentialsMetaInput -from backend.data.user import get_user_by_id from backend.executor import utils as execution_utils from backend.util.clients import get_scheduler_client from backend.util.exceptions import DatabaseError, NotFoundError @@ -200,7 +196,7 @@ class RunAgentTool(BaseTool): # Priority: library_agent_id if provided if has_library_id: - library_agent = await library_db.get_library_agent( + library_agent = await library_db().get_library_agent( params.library_agent_id, user_id ) if not library_agent: @@ -209,9 +205,7 @@ class RunAgentTool(BaseTool): session_id=session_id, ) # Get the graph from the library agent - from backend.data.graph import get_graph - - graph = await get_graph( + graph = await graph_db().get_graph( library_agent.graph_id, library_agent.graph_version, user_id=user_id, @@ -522,7 +516,7 @@ class RunAgentTool(BaseTool): library_agent = await get_or_create_library_agent(graph, user_id) # Get user timezone - user = await get_user_by_id(user_id) + user = await user_db().get_user_by_id(user_id) user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone) # Create schedule diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py b/autogpt_platform/backend/backend/copilot/tools/run_agent_test.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py rename to autogpt_platform/backend/backend/copilot/tools/run_agent_test.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/copilot/tools/run_block.py similarity index 98% rename from autogpt_platform/backend/backend/api/features/chat/tools/run_block.py rename to autogpt_platform/backend/backend/copilot/tools/run_block.py index a55478326a..32f249626b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/copilot/tools/run_block.py @@ -7,20 +7,17 @@ from typing import Any from pydantic_core import PydanticUndefined -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools.find_block import ( - COPILOT_EXCLUDED_BLOCK_IDS, - COPILOT_EXCLUDED_BLOCK_TYPES, -) from backend.blocks import get_block from backend.blocks._base import AnyBlockSchema +from backend.copilot.model import ChatSession +from backend.data.db_accessors import workspace_db from backend.data.execution import ExecutionContext from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput -from backend.data.workspace import get_or_create_workspace from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import BlockError from .base import BaseTool +from .find_block import COPILOT_EXCLUDED_BLOCK_IDS, COPILOT_EXCLUDED_BLOCK_TYPES from .helpers import get_inputs_from_schema from .models import ( BlockDetails, @@ -276,7 +273,7 @@ class RunBlockTool(BaseTool): try: # Get or create user's workspace for CoPilot file operations - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Generate synthetic IDs for CoPilot context # Each chat session is treated as its own agent with one continuous run diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py b/autogpt_platform/backend/backend/copilot/tools/run_block_test.py similarity index 93% rename from autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py rename to autogpt_platform/backend/backend/copilot/tools/run_block_test.py index 55efc38479..7ab4d706a2 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/run_block_test.py @@ -4,16 +4,16 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from backend.api.features.chat.tools.models import ( +from backend.blocks._base import BlockType + +from ._test_data import make_session +from .models import ( BlockDetailsResponse, BlockOutputResponse, ErrorResponse, InputValidationErrorResponse, ) -from backend.api.features.chat.tools.run_block import RunBlockTool -from backend.blocks._base import BlockType - -from ._test_data import make_session +from .run_block import RunBlockTool _TEST_USER_ID = "test-user-run-block" @@ -77,7 +77,7 @@ class TestRunBlockFiltering: input_block = make_mock_block("input-block-id", "Input Block", BlockType.INPUT) with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=input_block, ): tool = RunBlockTool() @@ -103,7 +103,7 @@ class TestRunBlockFiltering: ) with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=smart_block, ): tool = RunBlockTool() @@ -127,7 +127,7 @@ class TestRunBlockFiltering: ) with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=standard_block, ): tool = RunBlockTool() @@ -183,7 +183,7 @@ class TestRunBlockInputValidation: ) with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=mock_block, ): tool = RunBlockTool() @@ -222,7 +222,7 @@ class TestRunBlockInputValidation: ) with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=mock_block, ): tool = RunBlockTool() @@ -263,7 +263,7 @@ class TestRunBlockInputValidation: ) with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=mock_block, ): tool = RunBlockTool() @@ -302,15 +302,19 @@ class TestRunBlockInputValidation: mock_block.execute = mock_execute + mock_workspace_db = MagicMock() + mock_workspace_db.get_or_create_workspace = AsyncMock( + return_value=MagicMock(id="test-workspace-id") + ) + with ( patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=mock_block, ), patch( - "backend.api.features.chat.tools.run_block.get_or_create_workspace", - new_callable=AsyncMock, - return_value=MagicMock(id="test-workspace-id"), + "backend.copilot.tools.run_block.workspace_db", + return_value=mock_workspace_db, ), ): tool = RunBlockTool() @@ -344,7 +348,7 @@ class TestRunBlockInputValidation: ) with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=mock_block, ): tool = RunBlockTool() diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py b/autogpt_platform/backend/backend/copilot/tools/sandbox.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tools/sandbox.py rename to autogpt_platform/backend/backend/copilot/tools/sandbox.py diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py b/autogpt_platform/backend/backend/copilot/tools/search_docs.py similarity index 95% rename from autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py rename to autogpt_platform/backend/backend/copilot/tools/search_docs.py index edb0c0de1e..b09fe64a2c 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/search_docs.py +++ b/autogpt_platform/backend/backend/copilot/tools/search_docs.py @@ -5,16 +5,17 @@ from typing import Any from prisma.enums import ContentType -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools.base import BaseTool -from backend.api.features.chat.tools.models import ( +from backend.copilot.model import ChatSession +from backend.data.db_accessors import search + +from .base import BaseTool +from .models import ( DocSearchResult, DocSearchResultsResponse, ErrorResponse, NoResultsResponse, ToolResponseBase, ) -from backend.api.features.store.hybrid_search import unified_hybrid_search logger = logging.getLogger(__name__) @@ -117,7 +118,7 @@ class SearchDocsTool(BaseTool): try: # Search using hybrid search for DOCUMENTATION content type only - results, total = await unified_hybrid_search( + results, total = await search().unified_hybrid_search( query=query, content_types=[ContentType.DOCUMENTATION], page=1, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py b/autogpt_platform/backend/backend/copilot/tools/test_run_block_details.py similarity index 95% rename from autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py rename to autogpt_platform/backend/backend/copilot/tools/test_run_block_details.py index fbab0b723d..d06fbb766d 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/test_run_block_details.py +++ b/autogpt_platform/backend/backend/copilot/tools/test_run_block_details.py @@ -4,13 +4,13 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from backend.api.features.chat.tools.models import BlockDetailsResponse -from backend.api.features.chat.tools.run_block import RunBlockTool from backend.blocks._base import BlockType from backend.data.model import CredentialsMetaInput from backend.integrations.providers import ProviderName from ._test_data import make_session +from .models import BlockDetailsResponse +from .run_block import RunBlockTool _TEST_USER_ID = "test-user-run-block-details" @@ -61,7 +61,7 @@ async def test_run_block_returns_details_when_no_input_provided(): ) with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=http_block, ): # Mock credentials check to return no missing credentials @@ -120,7 +120,7 @@ async def test_run_block_returns_details_when_only_credentials_provided(): } with patch( - "backend.api.features.chat.tools.run_block.get_block", + "backend.copilot.tools.run_block.get_block", return_value=mock, ): with patch.object( diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/copilot/tools/utils.py similarity index 97% rename from autogpt_platform/backend/backend/api/features/chat/tools/utils.py rename to autogpt_platform/backend/backend/copilot/tools/utils.py index 3b2168d09e..60747566a6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/copilot/tools/utils.py @@ -3,9 +3,8 @@ import logging from typing import Any -from backend.api.features.library import db as library_db from backend.api.features.library import model as library_model -from backend.api.features.store import db as store_db +from backend.data.db_accessors import library_db, store_db from backend.data.graph import GraphModel from backend.data.model import ( Credentials, @@ -39,13 +38,14 @@ async def fetch_graph_from_store_slug( Raises: DatabaseError: If there's a database error during lookup. """ + sdb = store_db() try: - store_agent = await store_db.get_store_agent_details(username, agent_name) + store_agent = await sdb.get_store_agent_details(username, agent_name) except NotFoundError: return None, None # Get the graph from store listing version - graph = await store_db.get_available_graph( + graph = await sdb.get_available_graph( store_agent.store_listing_version_id, hide_nodes=False ) return graph, store_agent @@ -210,13 +210,13 @@ async def get_or_create_library_agent( Returns: LibraryAgent instance """ - existing = await library_db.get_library_agent_by_graph_id( + existing = await library_db().get_library_agent_by_graph_id( graph_id=graph.id, user_id=user_id ) if existing: return existing - library_agents = await library_db.create_library_agent( + library_agents = await library_db().create_library_agent( graph=graph, user_id=user_id, create_library_agents_for_sub_graphs=False, diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py b/autogpt_platform/backend/backend/copilot/tools/web_fetch.py similarity index 94% rename from autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py rename to autogpt_platform/backend/backend/copilot/tools/web_fetch.py index fed7cc11fa..78ee2f9fe0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/web_fetch.py +++ b/autogpt_platform/backend/backend/copilot/tools/web_fetch.py @@ -6,15 +6,12 @@ from typing import Any import aiohttp import html2text -from backend.api.features.chat.model import ChatSession -from backend.api.features.chat.tools.base import BaseTool -from backend.api.features.chat.tools.models import ( - ErrorResponse, - ToolResponseBase, - WebFetchResponse, -) +from backend.copilot.model import ChatSession from backend.util.request import Requests +from .base import BaseTool +from .models import ErrorResponse, ToolResponseBase, WebFetchResponse + logger = logging.getLogger(__name__) # Limits diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py b/autogpt_platform/backend/backend/copilot/tools/workspace_files.py similarity index 96% rename from autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py rename to autogpt_platform/backend/backend/copilot/tools/workspace_files.py index f37d2c80e0..2f0e225483 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/workspace_files.py +++ b/autogpt_platform/backend/backend/copilot/tools/workspace_files.py @@ -6,8 +6,8 @@ from typing import Any, Optional from pydantic import BaseModel -from backend.api.features.chat.model import ChatSession -from backend.data.workspace import get_or_create_workspace +from backend.copilot.model import ChatSession +from backend.data.db_accessors import workspace_db from backend.util.settings import Config from backend.util.virus_scanner import scan_content_safe from backend.util.workspace import WorkspaceManager @@ -148,7 +148,7 @@ class ListWorkspaceFilesTool(BaseTool): include_all_sessions: bool = kwargs.get("include_all_sessions", False) try: - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Pass session_id for session-scoped file access manager = WorkspaceManager(user_id, workspace.id, session_id) @@ -167,8 +167,8 @@ class ListWorkspaceFilesTool(BaseTool): file_id=f.id, name=f.name, path=f.path, - mime_type=f.mimeType, - size_bytes=f.sizeBytes, + mime_type=f.mime_type, + size_bytes=f.size_bytes, ) for f in files ] @@ -284,7 +284,7 @@ class ReadWorkspaceFileTool(BaseTool): ) try: - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Pass session_id for session-scoped file access manager = WorkspaceManager(user_id, workspace.id, session_id) @@ -309,8 +309,8 @@ class ReadWorkspaceFileTool(BaseTool): target_file_id = file_info.id # Decide whether to return inline content or metadata+URL - is_small_file = file_info.sizeBytes <= self.MAX_INLINE_SIZE_BYTES - is_text_file = self._is_text_mime_type(file_info.mimeType) + is_small_file = file_info.size_bytes <= self.MAX_INLINE_SIZE_BYTES + is_text_file = self._is_text_mime_type(file_info.mime_type) # Return inline content for small text files (unless force_download_url) if is_small_file and is_text_file and not force_download_url: @@ -321,7 +321,7 @@ class ReadWorkspaceFileTool(BaseTool): file_id=file_info.id, name=file_info.name, path=file_info.path, - mime_type=file_info.mimeType, + mime_type=file_info.mime_type, content_base64=content_b64, message=f"Successfully read file: {file_info.name}", session_id=session_id, @@ -350,11 +350,11 @@ class ReadWorkspaceFileTool(BaseTool): file_id=file_info.id, name=file_info.name, path=file_info.path, - mime_type=file_info.mimeType, - size_bytes=file_info.sizeBytes, + mime_type=file_info.mime_type, + size_bytes=file_info.size_bytes, download_url=download_url, preview=preview, - message=f"File: {file_info.name} ({file_info.sizeBytes} bytes). Use download_url to retrieve content.", + message=f"File: {file_info.name} ({file_info.size_bytes} bytes). Use download_url to retrieve content.", session_id=session_id, ) @@ -484,7 +484,7 @@ class WriteWorkspaceFileTool(BaseTool): # Virus scan await scan_content_safe(content, filename=filename) - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Pass session_id for session-scoped file access manager = WorkspaceManager(user_id, workspace.id, session_id) @@ -500,7 +500,7 @@ class WriteWorkspaceFileTool(BaseTool): file_id=file_record.id, name=file_record.name, path=file_record.path, - size_bytes=file_record.sizeBytes, + size_bytes=file_record.size_bytes, message=f"Successfully wrote file: {file_record.name}", session_id=session_id, ) @@ -583,7 +583,7 @@ class DeleteWorkspaceFileTool(BaseTool): ) try: - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Pass session_id for session-scoped file access manager = WorkspaceManager(user_id, workspace.id, session_id) diff --git a/autogpt_platform/backend/backend/api/features/chat/tracking.py b/autogpt_platform/backend/backend/copilot/tracking.py similarity index 100% rename from autogpt_platform/backend/backend/api/features/chat/tracking.py rename to autogpt_platform/backend/backend/copilot/tracking.py diff --git a/autogpt_platform/backend/backend/data/db_accessors.py b/autogpt_platform/backend/backend/data/db_accessors.py new file mode 100644 index 0000000000..9875cabec5 --- /dev/null +++ b/autogpt_platform/backend/backend/data/db_accessors.py @@ -0,0 +1,118 @@ +from backend.data import db + + +def chat_db(): + if db.is_connected(): + from backend.copilot import db as _chat_db + + chat_db = _chat_db + else: + from backend.util.clients import get_database_manager_async_client + + chat_db = get_database_manager_async_client() + + return chat_db + + +def graph_db(): + if db.is_connected(): + from backend.data import graph as _graph_db + + graph_db = _graph_db + else: + from backend.util.clients import get_database_manager_async_client + + graph_db = get_database_manager_async_client() + + return graph_db + + +def library_db(): + if db.is_connected(): + from backend.api.features.library import db as _library_db + + library_db = _library_db + else: + from backend.util.clients import get_database_manager_async_client + + library_db = get_database_manager_async_client() + + return library_db + + +def store_db(): + if db.is_connected(): + from backend.api.features.store import db as _store_db + + store_db = _store_db + else: + from backend.util.clients import get_database_manager_async_client + + store_db = get_database_manager_async_client() + + return store_db + + +def search(): + if db.is_connected(): + from backend.api.features.store import hybrid_search as _search + + search = _search + else: + from backend.util.clients import get_database_manager_async_client + + search = get_database_manager_async_client() + + return search + + +def execution_db(): + if db.is_connected(): + from backend.data import execution as _execution_db + + execution_db = _execution_db + else: + from backend.util.clients import get_database_manager_async_client + + execution_db = get_database_manager_async_client() + + return execution_db + + +def user_db(): + if db.is_connected(): + from backend.data import user as _user_db + + user_db = _user_db + else: + from backend.util.clients import get_database_manager_async_client + + user_db = get_database_manager_async_client() + + return user_db + + +def understanding_db(): + if db.is_connected(): + from backend.data import understanding as _understanding_db + + understanding_db = _understanding_db + else: + from backend.util.clients import get_database_manager_async_client + + understanding_db = get_database_manager_async_client() + + return understanding_db + + +def workspace_db(): + if db.is_connected(): + from backend.data import workspace as _workspace_db + + workspace_db = _workspace_db + else: + from backend.util.clients import get_database_manager_async_client + + workspace_db = get_database_manager_async_client() + + return workspace_db diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/data/db_manager.py similarity index 72% rename from autogpt_platform/backend/backend/executor/database.py rename to autogpt_platform/backend/backend/data/db_manager.py index d44439d51c..090c21ad7c 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/data/db_manager.py @@ -4,14 +4,26 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas from backend.api.features.library.db import ( add_store_agent_to_library, + create_graph_in_library, + create_library_agent, + get_library_agent, + get_library_agent_by_graph_id, list_library_agents, + update_graph_in_library, +) +from backend.api.features.store.db import ( + get_agent, + get_available_graph, + get_store_agent_details, + get_store_agents, ) -from backend.api.features.store.db import get_store_agent_details, get_store_agents from backend.api.features.store.embeddings import ( backfill_missing_embeddings, cleanup_orphaned_embeddings, get_embedding_stats, ) +from backend.api.features.store.hybrid_search import unified_hybrid_search +from backend.copilot import db as chat_db from backend.data import db from backend.data.analytics import ( get_accuracy_trends_and_alerts, @@ -48,6 +60,7 @@ from backend.data.graph import ( get_graph_metadata, get_graph_settings, get_node, + get_store_listed_graphs, validate_graph_execution_permissions, ) from backend.data.human_review import ( @@ -67,6 +80,10 @@ from backend.data.notifications import ( remove_notifications_from_batch, ) from backend.data.onboarding import increment_onboarding_runs +from backend.data.understanding import ( + get_business_understanding, + upsert_business_understanding, +) from backend.data.user import ( get_active_user_ids_in_timerange, get_user_by_id, @@ -76,6 +93,7 @@ from backend.data.user import ( get_user_notification_preference, update_user_integrations, ) +from backend.data.workspace import get_or_create_workspace from backend.util.service import ( AppService, AppServiceClient, @@ -107,6 +125,13 @@ async def _get_credits(user_id: str) -> int: class DatabaseManager(AppService): + """Database connection pooling service. + + This service connects to the Prisma engine and exposes database + operations via RPC endpoints. It acts as a centralized connection pool + for all services that need database access. + """ + @asynccontextmanager async def lifespan(self, app: "FastAPI"): async with super().lifespan(app): @@ -142,11 +167,15 @@ class DatabaseManager(AppService): def _( f: Callable[P, R], name: str | None = None ) -> Callable[Concatenate[object, P], R]: + """ + Exposes a function as an RPC endpoint, and adds a virtual `self` param + to the function's type so it can be bound as a method. + """ if name is not None: f.__name__ = name return cast(Callable[Concatenate[object, P], R], expose(f)) - # Executions + # ============ Graph Executions ============ # get_child_graph_executions = _(get_child_graph_executions) get_graph_executions = _(get_graph_executions) get_graph_executions_count = _(get_graph_executions_count) @@ -170,36 +199,37 @@ class DatabaseManager(AppService): get_frequently_executed_graphs = _(get_frequently_executed_graphs) get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring) - # Graphs + # ============ Graphs ============ # get_node = _(get_node) get_graph = _(get_graph) get_connected_output_nodes = _(get_connected_output_nodes) get_graph_metadata = _(get_graph_metadata) get_graph_settings = _(get_graph_settings) + get_store_listed_graphs = _(get_store_listed_graphs) - # Credits + # ============ Credits ============ # spend_credits = _(_spend_credits, name="spend_credits") get_credits = _(_get_credits, name="get_credits") - # User + User Metadata + User Integrations + # ============ User + Integrations ============ # + get_user_by_id = _(get_user_by_id) get_user_integrations = _(get_user_integrations) update_user_integrations = _(update_user_integrations) - # User Comms - async + # ============ User Comms ============ # get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange) - get_user_by_id = _(get_user_by_id) get_user_email_by_id = _(get_user_email_by_id) get_user_email_verification = _(get_user_email_verification) get_user_notification_preference = _(get_user_notification_preference) - # Human In The Loop + # ============ Human In The Loop ============ # cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution) check_approval = _(check_approval) get_or_create_human_review = _(get_or_create_human_review) has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec) update_review_processed_status = _(update_review_processed_status) - # Notifications - async + # ============ Notifications ============ # clear_all_user_notification_batches = _(clear_all_user_notification_batches) create_or_add_to_user_notification_batch = _( create_or_add_to_user_notification_batch @@ -212,29 +242,56 @@ class DatabaseManager(AppService): get_user_notification_oldest_message_in_batch ) - # Library + # ============ Library ============ # list_library_agents = _(list_library_agents) add_store_agent_to_library = _(add_store_agent_to_library) + create_graph_in_library = _(create_graph_in_library) + create_library_agent = _(create_library_agent) + get_library_agent = _(get_library_agent) + get_library_agent_by_graph_id = _(get_library_agent_by_graph_id) + update_graph_in_library = _(update_graph_in_library) validate_graph_execution_permissions = _(validate_graph_execution_permissions) - # Onboarding + # ============ Onboarding ============ # increment_onboarding_runs = _(increment_onboarding_runs) - # OAuth + # ============ OAuth ============ # cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens) - # Store + # ============ Store ============ # get_store_agents = _(get_store_agents) get_store_agent_details = _(get_store_agent_details) + get_agent = _(get_agent) + get_available_graph = _(get_available_graph) - # Store Embeddings + # ============ Search ============ # get_embedding_stats = _(get_embedding_stats) backfill_missing_embeddings = _(backfill_missing_embeddings) cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings) + unified_hybrid_search = _(unified_hybrid_search) - # Summary data - async + # ============ Summary Data ============ # get_user_execution_summary_data = _(get_user_execution_summary_data) + # ============ Workspace ============ # + get_or_create_workspace = _(get_or_create_workspace) + + # ============ Understanding ============ # + get_business_understanding = _(get_business_understanding) + upsert_business_understanding = _(upsert_business_understanding) + + # ============ CoPilot Chat Sessions ============ # + get_chat_session = _(chat_db.get_chat_session) + create_chat_session = _(chat_db.create_chat_session) + update_chat_session = _(chat_db.update_chat_session) + add_chat_message = _(chat_db.add_chat_message) + add_chat_messages_batch = _(chat_db.add_chat_messages_batch) + get_user_chat_sessions = _(chat_db.get_user_chat_sessions) + get_user_session_count = _(chat_db.get_user_session_count) + delete_chat_session = _(chat_db.delete_chat_session) + get_chat_session_message_count = _(chat_db.get_chat_session_message_count) + update_tool_message_content = _(chat_db.update_tool_message_content) + class DatabaseManagerClient(AppServiceClient): d = DatabaseManager @@ -296,43 +353,50 @@ class DatabaseManagerAsyncClient(AppServiceClient): def get_service_type(cls): return DatabaseManager + # ============ Graph Executions ============ # create_graph_execution = d.create_graph_execution get_child_graph_executions = d.get_child_graph_executions get_connected_output_nodes = d.get_connected_output_nodes get_latest_node_execution = d.get_latest_node_execution - get_graph = d.get_graph - get_graph_metadata = d.get_graph_metadata - get_graph_settings = d.get_graph_settings get_graph_execution = d.get_graph_execution get_graph_execution_meta = d.get_graph_execution_meta - get_node = d.get_node + get_graph_executions = d.get_graph_executions get_node_execution = d.get_node_execution get_node_executions = d.get_node_executions - get_user_by_id = d.get_user_by_id - get_user_integrations = d.get_user_integrations - upsert_execution_input = d.upsert_execution_input - upsert_execution_output = d.upsert_execution_output - get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id update_graph_execution_stats = d.update_graph_execution_stats update_node_execution_status = d.update_node_execution_status update_node_execution_status_batch = d.update_node_execution_status_batch - update_user_integrations = d.update_user_integrations + upsert_execution_input = d.upsert_execution_input + upsert_execution_output = d.upsert_execution_output + get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id get_execution_kv_data = d.get_execution_kv_data set_execution_kv_data = d.set_execution_kv_data - # Human In The Loop + # ============ Graphs ============ # + get_graph = d.get_graph + get_graph_metadata = d.get_graph_metadata + get_graph_settings = d.get_graph_settings + get_node = d.get_node + get_store_listed_graphs = d.get_store_listed_graphs + + # ============ User + Integrations ============ # + get_user_by_id = d.get_user_by_id + get_user_integrations = d.get_user_integrations + update_user_integrations = d.update_user_integrations + + # ============ Human In The Loop ============ # cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution check_approval = d.check_approval get_or_create_human_review = d.get_or_create_human_review update_review_processed_status = d.update_review_processed_status - # User Comms + # ============ User Comms ============ # get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange get_user_email_by_id = d.get_user_email_by_id get_user_email_verification = d.get_user_email_verification get_user_notification_preference = d.get_user_notification_preference - # Notifications + # ============ Notifications ============ # clear_all_user_notification_batches = d.clear_all_user_notification_batches create_or_add_to_user_notification_batch = ( d.create_or_add_to_user_notification_batch @@ -345,20 +409,49 @@ class DatabaseManagerAsyncClient(AppServiceClient): d.get_user_notification_oldest_message_in_batch ) - # Library + # ============ Library ============ # list_library_agents = d.list_library_agents add_store_agent_to_library = d.add_store_agent_to_library + create_graph_in_library = d.create_graph_in_library + create_library_agent = d.create_library_agent + get_library_agent = d.get_library_agent + get_library_agent_by_graph_id = d.get_library_agent_by_graph_id + update_graph_in_library = d.update_graph_in_library validate_graph_execution_permissions = d.validate_graph_execution_permissions - # Onboarding + # ============ Onboarding ============ # increment_onboarding_runs = d.increment_onboarding_runs - # OAuth + # ============ OAuth ============ # cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens - # Store + # ============ Store ============ # get_store_agents = d.get_store_agents get_store_agent_details = d.get_store_agent_details + get_agent = d.get_agent + get_available_graph = d.get_available_graph - # Summary data + # ============ Search ============ # + unified_hybrid_search = d.unified_hybrid_search + + # ============ Summary Data ============ # get_user_execution_summary_data = d.get_user_execution_summary_data + + # ============ Workspace ============ # + get_or_create_workspace = d.get_or_create_workspace + + # ============ Understanding ============ # + get_business_understanding = d.get_business_understanding + upsert_business_understanding = d.upsert_business_understanding + + # ============ CoPilot Chat Sessions ============ # + get_chat_session = d.get_chat_session + create_chat_session = d.create_chat_session + update_chat_session = d.update_chat_session + add_chat_message = d.add_chat_message + add_chat_messages_batch = d.add_chat_messages_batch + get_user_chat_sessions = d.get_user_chat_sessions + get_user_session_count = d.get_user_session_count + delete_chat_session = d.delete_chat_session + get_chat_session_message_count = d.get_chat_session_message_count + update_tool_message_content = d.update_tool_message_content diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index 94f99852e8..a9a92dcdd5 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -1147,14 +1147,14 @@ async def get_graph( return GraphModel.from_db(graph, for_export) -async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]: +async def get_store_listed_graphs(graph_ids: list[str]) -> dict[str, GraphModel]: """Batch-fetch multiple store-listed graphs by their IDs. Only returns graphs that have approved store listings (publicly available). Does not require permission checks since store-listed graphs are public. Args: - *graph_ids: Variable number of graph IDs to fetch + graph_ids: List of graph IDs to fetch Returns: Dict mapping graph_id to GraphModel for graphs with approved store listings diff --git a/autogpt_platform/backend/backend/data/workspace.py b/autogpt_platform/backend/backend/data/workspace.py index f3dba0a294..fdf378747d 100644 --- a/autogpt_platform/backend/backend/data/workspace.py +++ b/autogpt_platform/backend/backend/data/workspace.py @@ -8,6 +8,7 @@ import logging from datetime import datetime, timezone from typing import Optional +import pydantic from prisma.models import UserWorkspace, UserWorkspaceFile from prisma.types import UserWorkspaceFileWhereInput @@ -16,7 +17,61 @@ from backend.util.json import SafeJson logger = logging.getLogger(__name__) -async def get_or_create_workspace(user_id: str) -> UserWorkspace: +class Workspace(pydantic.BaseModel): + """Pydantic model for UserWorkspace, safe for RPC transport.""" + + id: str + user_id: str + created_at: datetime + updated_at: datetime + + @staticmethod + def from_db(workspace: "UserWorkspace") -> "Workspace": + return Workspace( + id=workspace.id, + user_id=workspace.userId, + created_at=workspace.createdAt, + updated_at=workspace.updatedAt, + ) + + +class WorkspaceFile(pydantic.BaseModel): + """Pydantic model for UserWorkspaceFile, safe for RPC transport.""" + + id: str + workspace_id: str + created_at: datetime + updated_at: datetime + name: str + path: str + storage_path: str + mime_type: str + size_bytes: int + checksum: Optional[str] = None + is_deleted: bool = False + deleted_at: Optional[datetime] = None + metadata: dict = pydantic.Field(default_factory=dict) + + @staticmethod + def from_db(file: "UserWorkspaceFile") -> "WorkspaceFile": + return WorkspaceFile( + id=file.id, + workspace_id=file.workspaceId, + created_at=file.createdAt, + updated_at=file.updatedAt, + name=file.name, + path=file.path, + storage_path=file.storagePath, + mime_type=file.mimeType, + size_bytes=file.sizeBytes, + checksum=file.checksum, + is_deleted=file.isDeleted, + deleted_at=file.deletedAt, + metadata=file.metadata if isinstance(file.metadata, dict) else {}, + ) + + +async def get_or_create_workspace(user_id: str) -> Workspace: """ Get user's workspace, creating one if it doesn't exist. @@ -27,7 +82,7 @@ async def get_or_create_workspace(user_id: str) -> UserWorkspace: user_id: The user's ID Returns: - UserWorkspace instance + Workspace instance """ workspace = await UserWorkspace.prisma().upsert( where={"userId": user_id}, @@ -37,10 +92,10 @@ async def get_or_create_workspace(user_id: str) -> UserWorkspace: }, ) - return workspace + return Workspace.from_db(workspace) -async def get_workspace(user_id: str) -> Optional[UserWorkspace]: +async def get_workspace(user_id: str) -> Optional[Workspace]: """ Get user's workspace if it exists. @@ -48,9 +103,10 @@ async def get_workspace(user_id: str) -> Optional[UserWorkspace]: user_id: The user's ID Returns: - UserWorkspace instance or None + Workspace instance or None """ - return await UserWorkspace.prisma().find_unique(where={"userId": user_id}) + workspace = await UserWorkspace.prisma().find_unique(where={"userId": user_id}) + return Workspace.from_db(workspace) if workspace else None async def create_workspace_file( @@ -63,7 +119,7 @@ async def create_workspace_file( size_bytes: int, checksum: Optional[str] = None, metadata: Optional[dict] = None, -) -> UserWorkspaceFile: +) -> WorkspaceFile: """ Create a new workspace file record. @@ -79,7 +135,7 @@ async def create_workspace_file( metadata: Optional additional metadata Returns: - Created UserWorkspaceFile instance + Created WorkspaceFile instance """ # Normalize path to start with / if not path.startswith("/"): @@ -103,13 +159,13 @@ async def create_workspace_file( f"Created workspace file {file.id} at path {path} " f"in workspace {workspace_id}" ) - return file + return WorkspaceFile.from_db(file) async def get_workspace_file( file_id: str, workspace_id: Optional[str] = None, -) -> Optional[UserWorkspaceFile]: +) -> Optional[WorkspaceFile]: """ Get a workspace file by ID. @@ -118,19 +174,20 @@ async def get_workspace_file( workspace_id: Optional workspace ID for validation Returns: - UserWorkspaceFile instance or None + WorkspaceFile instance or None """ where_clause: dict = {"id": file_id, "isDeleted": False} if workspace_id: where_clause["workspaceId"] = workspace_id - return await UserWorkspaceFile.prisma().find_first(where=where_clause) + file = await UserWorkspaceFile.prisma().find_first(where=where_clause) + return WorkspaceFile.from_db(file) if file else None async def get_workspace_file_by_path( workspace_id: str, path: str, -) -> Optional[UserWorkspaceFile]: +) -> Optional[WorkspaceFile]: """ Get a workspace file by its virtual path. @@ -139,19 +196,20 @@ async def get_workspace_file_by_path( path: Virtual path Returns: - UserWorkspaceFile instance or None + WorkspaceFile instance or None """ # Normalize path if not path.startswith("/"): path = f"/{path}" - return await UserWorkspaceFile.prisma().find_first( + file = await UserWorkspaceFile.prisma().find_first( where={ "workspaceId": workspace_id, "path": path, "isDeleted": False, } ) + return WorkspaceFile.from_db(file) if file else None async def list_workspace_files( @@ -160,7 +218,7 @@ async def list_workspace_files( include_deleted: bool = False, limit: Optional[int] = None, offset: int = 0, -) -> list[UserWorkspaceFile]: +) -> list[WorkspaceFile]: """ List files in a workspace. @@ -172,7 +230,7 @@ async def list_workspace_files( offset: Number of files to skip Returns: - List of UserWorkspaceFile instances + List of WorkspaceFile instances """ where_clause: UserWorkspaceFileWhereInput = {"workspaceId": workspace_id} @@ -185,12 +243,13 @@ async def list_workspace_files( path_prefix = f"/{path_prefix}" where_clause["path"] = {"startswith": path_prefix} - return await UserWorkspaceFile.prisma().find_many( + files = await UserWorkspaceFile.prisma().find_many( where=where_clause, order={"createdAt": "desc"}, take=limit, skip=offset, ) + return [WorkspaceFile.from_db(f) for f in files] async def count_workspace_files( @@ -225,7 +284,7 @@ async def count_workspace_files( async def soft_delete_workspace_file( file_id: str, workspace_id: Optional[str] = None, -) -> Optional[UserWorkspaceFile]: +) -> Optional[WorkspaceFile]: """ Soft-delete a workspace file. @@ -237,7 +296,7 @@ async def soft_delete_workspace_file( workspace_id: Optional workspace ID for validation Returns: - Updated UserWorkspaceFile instance or None if not found + Updated WorkspaceFile instance or None if not found """ # First verify the file exists and belongs to workspace file = await get_workspace_file(file_id, workspace_id) @@ -259,7 +318,7 @@ async def soft_delete_workspace_file( ) logger.info(f"Soft-deleted workspace file {file_id}") - return updated + return WorkspaceFile.from_db(updated) if updated else None async def get_workspace_total_size(workspace_id: str) -> int: @@ -273,4 +332,4 @@ async def get_workspace_total_size(workspace_id: str) -> int: Total size in bytes """ files = await list_workspace_files(workspace_id) - return sum(file.sizeBytes for file in files) + return sum(file.size_bytes for file in files) diff --git a/autogpt_platform/backend/backend/db.py b/autogpt_platform/backend/backend/db.py index 5c59a98a00..2661405f6d 100644 --- a/autogpt_platform/backend/backend/db.py +++ b/autogpt_platform/backend/backend/db.py @@ -1,5 +1,5 @@ from backend.app import run_processes -from backend.executor import DatabaseManager +from backend.data.db_manager import DatabaseManager def main(): diff --git a/autogpt_platform/backend/backend/executor/__init__.py b/autogpt_platform/backend/backend/executor/__init__.py index 92d8b5dc58..883bb226e6 100644 --- a/autogpt_platform/backend/backend/executor/__init__.py +++ b/autogpt_platform/backend/backend/executor/__init__.py @@ -1,11 +1,7 @@ -from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient from .manager import ExecutionManager from .scheduler import Scheduler __all__ = [ - "DatabaseManager", - "DatabaseManagerClient", - "DatabaseManagerAsyncClient", "ExecutionManager", "Scheduler", ] diff --git a/autogpt_platform/backend/backend/executor/activity_status_generator.py b/autogpt_platform/backend/backend/executor/activity_status_generator.py index 8cc1da8957..dd5d7ddfc4 100644 --- a/autogpt_platform/backend/backend/executor/activity_status_generator.py +++ b/autogpt_platform/backend/backend/executor/activity_status_generator.py @@ -22,7 +22,7 @@ from backend.util.settings import Settings from backend.util.truncate import truncate if TYPE_CHECKING: - from backend.executor import DatabaseManagerAsyncClient + from backend.data.db_manager import DatabaseManagerAsyncClient logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/executor/automod/manager.py b/autogpt_platform/backend/backend/executor/automod/manager.py index 81001196dd..2eef4f6eca 100644 --- a/autogpt_platform/backend/backend/executor/automod/manager.py +++ b/autogpt_platform/backend/backend/executor/automod/manager.py @@ -4,7 +4,7 @@ import logging from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: - from backend.executor import DatabaseManagerAsyncClient + from backend.data.db_manager import DatabaseManagerAsyncClient from pydantic import ValidationError diff --git a/autogpt_platform/backend/backend/executor/cluster_lock.py b/autogpt_platform/backend/backend/executor/cluster_lock.py index ad6362b535..c049a3c8ff 100644 --- a/autogpt_platform/backend/backend/executor/cluster_lock.py +++ b/autogpt_platform/backend/backend/executor/cluster_lock.py @@ -1,6 +1,7 @@ """Redis-based distributed locking for cluster coordination.""" import logging +import threading import time from typing import TYPE_CHECKING @@ -19,6 +20,7 @@ class ClusterLock: self.owner_id = owner_id self.timeout = timeout self._last_refresh = 0.0 + self._refresh_lock = threading.Lock() def try_acquire(self) -> str | None: """Try to acquire the lock. @@ -31,7 +33,8 @@ class ClusterLock: try: success = self.redis.set(self.key, self.owner_id, nx=True, ex=self.timeout) if success: - self._last_refresh = time.time() + with self._refresh_lock: + self._last_refresh = time.time() return self.owner_id # Successfully acquired # Failed to acquire, get current owner @@ -57,23 +60,27 @@ class ClusterLock: Rate limited to at most once every timeout/10 seconds (minimum 1 second). During rate limiting, still verifies lock existence but skips TTL extension. Setting _last_refresh to 0 bypasses rate limiting for testing. + + Thread-safe: uses _refresh_lock to protect _last_refresh access. """ # Calculate refresh interval: max(timeout // 10, 1) refresh_interval = max(self.timeout // 10, 1) current_time = time.time() - # Check if we're within the rate limit period + # Check if we're within the rate limit period (thread-safe read) # _last_refresh == 0 forces a refresh (bypasses rate limiting for testing) + with self._refresh_lock: + last_refresh = self._last_refresh is_rate_limited = ( - self._last_refresh > 0 - and (current_time - self._last_refresh) < refresh_interval + last_refresh > 0 and (current_time - last_refresh) < refresh_interval ) try: # Always verify lock existence, even during rate limiting current_value = self.redis.get(self.key) if not current_value: - self._last_refresh = 0 + with self._refresh_lock: + self._last_refresh = 0 return False stored_owner = ( @@ -82,7 +89,8 @@ class ClusterLock: else str(current_value) ) if stored_owner != self.owner_id: - self._last_refresh = 0 + with self._refresh_lock: + self._last_refresh = 0 return False # If rate limited, return True but don't update TTL or timestamp @@ -91,25 +99,30 @@ class ClusterLock: # Perform actual refresh if self.redis.expire(self.key, self.timeout): - self._last_refresh = current_time + with self._refresh_lock: + self._last_refresh = current_time return True - self._last_refresh = 0 + with self._refresh_lock: + self._last_refresh = 0 return False except Exception as e: logger.error(f"ClusterLock.refresh failed for key {self.key}: {e}") - self._last_refresh = 0 + with self._refresh_lock: + self._last_refresh = 0 return False def release(self): """Release the lock.""" - if self._last_refresh == 0: - return + with self._refresh_lock: + if self._last_refresh == 0: + return try: self.redis.delete(self.key) except Exception: pass - self._last_refresh = 0.0 + with self._refresh_lock: + self._last_refresh = 0.0 diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index caa98784c2..4444e15d22 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -93,7 +93,10 @@ from .utils import ( ) if TYPE_CHECKING: - from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient + from backend.data.db_manager import ( + DatabaseManagerAsyncClient, + DatabaseManagerClient, + ) _logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/util/clients.py b/autogpt_platform/backend/backend/util/clients.py index 570e9fa3de..ed884e9e6c 100644 --- a/autogpt_platform/backend/backend/util/clients.py +++ b/autogpt_platform/backend/backend/util/clients.py @@ -13,12 +13,15 @@ if TYPE_CHECKING: from openai import AsyncOpenAI from supabase import AClient, Client + from backend.data.db_manager import ( + DatabaseManagerAsyncClient, + DatabaseManagerClient, + ) from backend.data.execution import ( AsyncRedisExecutionEventBus, RedisExecutionEventBus, ) from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ - from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient from backend.executor.scheduler import SchedulerClient from backend.integrations.credentials_store import IntegrationCredentialsStore from backend.notifications.notifications import NotificationManagerClient @@ -27,7 +30,7 @@ if TYPE_CHECKING: @thread_cached def get_database_manager_client() -> "DatabaseManagerClient": """Get a thread-cached DatabaseManagerClient with request retry enabled.""" - from backend.executor import DatabaseManagerClient + from backend.data.db_manager import DatabaseManagerClient from backend.util.service import get_service_client return get_service_client(DatabaseManagerClient, request_retry=True) @@ -38,7 +41,7 @@ def get_database_manager_async_client( should_retry: bool = True, ) -> "DatabaseManagerAsyncClient": """Get a thread-cached DatabaseManagerAsyncClient with request retry enabled.""" - from backend.executor import DatabaseManagerAsyncClient + from backend.data.db_manager import DatabaseManagerAsyncClient from backend.util.service import get_service_client return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry) @@ -106,6 +109,20 @@ async def get_async_execution_queue() -> "AsyncRabbitMQ": return client +# ============ CoPilot Queue Helpers ============ # + + +@thread_cached +async def get_async_copilot_queue() -> "AsyncRabbitMQ": + """Get a thread-cached AsyncRabbitMQ CoPilot queue client.""" + from backend.copilot.executor.utils import create_copilot_queue_config + from backend.data.rabbitmq import AsyncRabbitMQ + + client = AsyncRabbitMQ(create_copilot_queue_config()) + await client.connect() + return client + + # ============ Integration Credentials Store ============ # diff --git a/autogpt_platform/backend/backend/util/file.py b/autogpt_platform/backend/backend/util/file.py index 70e354a29c..a30b76b33e 100644 --- a/autogpt_platform/backend/backend/util/file.py +++ b/autogpt_platform/backend/backend/util/file.py @@ -383,7 +383,7 @@ async def store_media_file( else: info = await workspace_manager.get_file_info(ws.file_ref) if info: - return MediaFileType(f"{file}#{info.mimeType}") + return MediaFileType(f"{file}#{info.mime_type}") except Exception: pass return MediaFileType(file) @@ -397,7 +397,7 @@ async def store_media_file( filename=filename, overwrite=True, ) - return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}") + return MediaFileType(f"workspace://{file_record.id}#{file_record.mime_type}") else: raise ValueError(f"Invalid return_format: {return_format}") diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index 20ee10f0c3..1bd4044709 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -211,16 +211,23 @@ class Config(UpdateTrackingModel["Config"], BaseSettings): description="The port for execution manager daemon to run on", ) + num_copilot_workers: int = Field( + default=5, + ge=1, + le=100, + description="Number of concurrent CoPilot executor workers", + ) + + copilot_executor_port: int = Field( + default=8008, + description="The port for CoPilot executor daemon to run on", + ) + execution_scheduler_port: int = Field( default=8003, description="The port for execution scheduler daemon to run on", ) - agent_server_port: int = Field( - default=8004, - description="The port for agent server daemon to run on", - ) - database_api_port: int = Field( default=8005, description="The port for database server API to run on", diff --git a/autogpt_platform/backend/backend/util/test.py b/autogpt_platform/backend/backend/util/test.py index 279b3142a4..26b255e019 100644 --- a/autogpt_platform/backend/backend/util/test.py +++ b/autogpt_platform/backend/backend/util/test.py @@ -11,6 +11,7 @@ from backend.api.rest_api import AgentServer from backend.blocks._base import Block, BlockSchema from backend.data import db from backend.data.block import initialize_blocks +from backend.data.db_manager import DatabaseManager from backend.data.execution import ( ExecutionContext, ExecutionStatus, @@ -19,7 +20,7 @@ from backend.data.execution import ( ) from backend.data.model import _BaseCredentials from backend.data.user import create_default_user -from backend.executor import DatabaseManager, ExecutionManager, Scheduler +from backend.executor import ExecutionManager, Scheduler from backend.notifications.notifications import NotificationManager log = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/util/workspace.py b/autogpt_platform/backend/backend/util/workspace.py index 86413b640a..453f4e9730 100644 --- a/autogpt_platform/backend/backend/util/workspace.py +++ b/autogpt_platform/backend/backend/util/workspace.py @@ -11,9 +11,9 @@ import uuid from typing import Optional from prisma.errors import UniqueViolationError -from prisma.models import UserWorkspaceFile from backend.data.workspace import ( + WorkspaceFile, count_workspace_files, create_workspace_file, get_workspace_file, @@ -131,7 +131,7 @@ class WorkspaceManager: raise FileNotFoundError(f"File not found at path: {resolved_path}") storage = await get_workspace_storage() - return await storage.retrieve(file.storagePath) + return await storage.retrieve(file.storage_path) async def read_file_by_id(self, file_id: str) -> bytes: """ @@ -151,7 +151,7 @@ class WorkspaceManager: raise FileNotFoundError(f"File not found: {file_id}") storage = await get_workspace_storage() - return await storage.retrieve(file.storagePath) + return await storage.retrieve(file.storage_path) async def write_file( self, @@ -160,7 +160,7 @@ class WorkspaceManager: path: Optional[str] = None, mime_type: Optional[str] = None, overwrite: bool = False, - ) -> UserWorkspaceFile: + ) -> WorkspaceFile: """ Write file to workspace. @@ -175,7 +175,7 @@ class WorkspaceManager: overwrite: Whether to overwrite existing file at path Returns: - Created UserWorkspaceFile instance + Created WorkspaceFile instance Raises: ValueError: If file exceeds size limit or path already exists @@ -296,7 +296,7 @@ class WorkspaceManager: limit: Optional[int] = None, offset: int = 0, include_all_sessions: bool = False, - ) -> list[UserWorkspaceFile]: + ) -> list[WorkspaceFile]: """ List files in workspace. @@ -311,7 +311,7 @@ class WorkspaceManager: If False (default), only list current session's files. Returns: - List of UserWorkspaceFile instances + List of WorkspaceFile instances """ effective_path = self._get_effective_path(path, include_all_sessions) @@ -339,7 +339,7 @@ class WorkspaceManager: # Delete from storage storage = await get_workspace_storage() try: - await storage.delete(file.storagePath) + await storage.delete(file.storage_path) except Exception as e: logger.warning(f"Failed to delete file from storage: {e}") # Continue with database soft-delete even if storage delete fails @@ -367,9 +367,9 @@ class WorkspaceManager: raise FileNotFoundError(f"File not found: {file_id}") storage = await get_workspace_storage() - return await storage.get_download_url(file.storagePath, expires_in) + return await storage.get_download_url(file.storage_path, expires_in) - async def get_file_info(self, file_id: str) -> Optional[UserWorkspaceFile]: + async def get_file_info(self, file_id: str) -> Optional[WorkspaceFile]: """ Get file metadata. @@ -377,11 +377,11 @@ class WorkspaceManager: file_id: The file's ID Returns: - UserWorkspaceFile instance or None + WorkspaceFile instance or None """ return await get_workspace_file(file_id, self.workspace_id) - async def get_file_info_by_path(self, path: str) -> Optional[UserWorkspaceFile]: + async def get_file_info_by_path(self, path: str) -> Optional[WorkspaceFile]: """ Get file metadata by path. @@ -392,7 +392,7 @@ class WorkspaceManager: path: Virtual path Returns: - UserWorkspaceFile instance or None + WorkspaceFile instance or None """ resolved_path = self._resolve_path(path) return await get_workspace_file_by_path(self.workspace_id, resolved_path) diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index 7a112e75ca..6467d15f49 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -117,6 +117,7 @@ ws = "backend.ws:main" scheduler = "backend.scheduler:main" notification = "backend.notification:main" executor = "backend.exec:main" +copilot-executor = "backend.copilot.executor.__main__:main" cli = "backend.cli:main" format = "linter:format" lint = "linter:lint" diff --git a/autogpt_platform/backend/test/agent_generator/test_core_integration.py b/autogpt_platform/backend/test/agent_generator/test_core_integration.py index 528763e751..74cb890b37 100644 --- a/autogpt_platform/backend/test/agent_generator/test_core_integration.py +++ b/autogpt_platform/backend/test/agent_generator/test_core_integration.py @@ -9,10 +9,8 @@ from unittest.mock import AsyncMock, patch import pytest -from backend.api.features.chat.tools.agent_generator import core -from backend.api.features.chat.tools.agent_generator.core import ( - AgentGeneratorNotConfiguredError, -) +from backend.copilot.tools.agent_generator import core +from backend.copilot.tools.agent_generator.core import AgentGeneratorNotConfiguredError class TestServiceNotConfigured: diff --git a/autogpt_platform/backend/test/agent_generator/test_library_agents.py b/autogpt_platform/backend/test/agent_generator/test_library_agents.py index 8387339582..508146af6a 100644 --- a/autogpt_platform/backend/test/agent_generator/test_library_agents.py +++ b/autogpt_platform/backend/test/agent_generator/test_library_agents.py @@ -9,7 +9,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from backend.api.features.chat.tools.agent_generator import core +from backend.copilot.tools.agent_generator import core class TestGetLibraryAgentsForGeneration: @@ -31,18 +31,20 @@ class TestGetLibraryAgentsForGeneration: mock_response = MagicMock() mock_response.agents = [mock_agent] + mock_db = MagicMock() + mock_db.list_library_agents = AsyncMock(return_value=mock_response) + with patch.object( - core.library_db, - "list_library_agents", - new_callable=AsyncMock, - return_value=mock_response, - ) as mock_list: + core, + "library_db", + return_value=mock_db, + ): result = await core.get_library_agents_for_generation( user_id="user-123", search_query="send email", ) - mock_list.assert_called_once_with( + mock_db.list_library_agents.assert_called_once_with( user_id="user-123", search_term="send email", page=1, @@ -80,11 +82,13 @@ class TestGetLibraryAgentsForGeneration: ), ] + mock_db = MagicMock() + mock_db.list_library_agents = AsyncMock(return_value=mock_response) + with patch.object( - core.library_db, - "list_library_agents", - new_callable=AsyncMock, - return_value=mock_response, + core, + "library_db", + return_value=mock_db, ): result = await core.get_library_agents_for_generation( user_id="user-123", @@ -101,18 +105,20 @@ class TestGetLibraryAgentsForGeneration: mock_response = MagicMock() mock_response.agents = [] + mock_db = MagicMock() + mock_db.list_library_agents = AsyncMock(return_value=mock_response) + with patch.object( - core.library_db, - "list_library_agents", - new_callable=AsyncMock, - return_value=mock_response, - ) as mock_list: + core, + "library_db", + return_value=mock_db, + ): await core.get_library_agents_for_generation( user_id="user-123", max_results=5, ) - mock_list.assert_called_once_with( + mock_db.list_library_agents.assert_called_once_with( user_id="user-123", search_term=None, page=1, @@ -144,24 +150,24 @@ class TestSearchMarketplaceAgentsForGeneration: mock_graph.input_schema = {"type": "object"} mock_graph.output_schema = {"type": "object"} + mock_store_db = MagicMock() + mock_store_db.get_store_agents = AsyncMock(return_value=mock_response) + + mock_graph_db = MagicMock() + mock_graph_db.get_store_listed_graphs = AsyncMock( + return_value={"graph-123": mock_graph} + ) + with ( - patch( - "backend.api.features.store.db.get_store_agents", - new_callable=AsyncMock, - return_value=mock_response, - ) as mock_search, - patch( - "backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs", - new_callable=AsyncMock, - return_value={"graph-123": mock_graph}, - ), + patch.object(core, "store_db", return_value=mock_store_db), + patch.object(core, "graph_db", return_value=mock_graph_db), ): result = await core.search_marketplace_agents_for_generation( search_query="automation", max_results=10, ) - mock_search.assert_called_once_with( + mock_store_db.get_store_agents.assert_called_once_with( search_query="automation", page=1, page_size=10, @@ -707,7 +713,7 @@ class TestExtractUuidsFromText: class TestGetLibraryAgentById: - """Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id).""" + """Test get_library_agent_by_id function (alias: get_library_agent_by_graph_id).""" @pytest.mark.asyncio async def test_returns_agent_when_found_by_graph_id(self): @@ -720,12 +726,10 @@ class TestGetLibraryAgentById: mock_agent.input_schema = {"properties": {}} mock_agent.output_schema = {"properties": {}} - with patch.object( - core.library_db, - "get_library_agent_by_graph_id", - new_callable=AsyncMock, - return_value=mock_agent, - ): + mock_db = MagicMock() + mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent) + + with patch.object(core, "library_db", return_value=mock_db): result = await core.get_library_agent_by_id("user-123", "agent-123") assert result is not None @@ -743,20 +747,11 @@ class TestGetLibraryAgentById: mock_agent.input_schema = {"properties": {}} mock_agent.output_schema = {"properties": {}} - with ( - patch.object( - core.library_db, - "get_library_agent_by_graph_id", - new_callable=AsyncMock, - return_value=None, # Not found by graph_id - ), - patch.object( - core.library_db, - "get_library_agent", - new_callable=AsyncMock, - return_value=mock_agent, # Found by library ID - ), - ): + mock_db = MagicMock() + mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None) + mock_db.get_library_agent = AsyncMock(return_value=mock_agent) + + with patch.object(core, "library_db", return_value=mock_db): result = await core.get_library_agent_by_id("user-123", "library-id-123") assert result is not None @@ -766,20 +761,13 @@ class TestGetLibraryAgentById: @pytest.mark.asyncio async def test_returns_none_when_not_found_by_either_method(self): """Test that None is returned when agent not found by either method.""" - with ( - patch.object( - core.library_db, - "get_library_agent_by_graph_id", - new_callable=AsyncMock, - return_value=None, - ), - patch.object( - core.library_db, - "get_library_agent", - new_callable=AsyncMock, - side_effect=core.NotFoundError("Not found"), - ), - ): + mock_db = MagicMock() + mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=None) + mock_db.get_library_agent = AsyncMock( + side_effect=core.NotFoundError("Not found") + ) + + with patch.object(core, "library_db", return_value=mock_db): result = await core.get_library_agent_by_id("user-123", "nonexistent") assert result is None @@ -787,27 +775,20 @@ class TestGetLibraryAgentById: @pytest.mark.asyncio async def test_returns_none_on_exception(self): """Test that None is returned when exception occurs in both lookups.""" - with ( - patch.object( - core.library_db, - "get_library_agent_by_graph_id", - new_callable=AsyncMock, - side_effect=Exception("Database error"), - ), - patch.object( - core.library_db, - "get_library_agent", - new_callable=AsyncMock, - side_effect=Exception("Database error"), - ), - ): + mock_db = MagicMock() + mock_db.get_library_agent_by_graph_id = AsyncMock( + side_effect=Exception("Database error") + ) + mock_db.get_library_agent = AsyncMock(side_effect=Exception("Database error")) + + with patch.object(core, "library_db", return_value=mock_db): result = await core.get_library_agent_by_id("user-123", "agent-123") assert result is None @pytest.mark.asyncio async def test_alias_works(self): - """Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id.""" + """Test that get_library_agent_by_graph_id is an alias.""" assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id @@ -828,20 +809,11 @@ class TestGetAllRelevantAgentsWithUuids: mock_response = MagicMock() mock_response.agents = [] - with ( - patch.object( - core.library_db, - "get_library_agent_by_graph_id", - new_callable=AsyncMock, - return_value=mock_agent, - ), - patch.object( - core.library_db, - "list_library_agents", - new_callable=AsyncMock, - return_value=mock_response, - ), - ): + mock_db = MagicMock() + mock_db.get_library_agent_by_graph_id = AsyncMock(return_value=mock_agent) + mock_db.list_library_agents = AsyncMock(return_value=mock_response) + + with patch.object(core, "library_db", return_value=mock_db): result = await core.get_all_relevant_agents_for_generation( user_id="user-123", search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d", diff --git a/autogpt_platform/backend/test/agent_generator/test_service.py b/autogpt_platform/backend/test/agent_generator/test_service.py index 93c9b9dcc0..587a64cbac 100644 --- a/autogpt_platform/backend/test/agent_generator/test_service.py +++ b/autogpt_platform/backend/test/agent_generator/test_service.py @@ -10,7 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import httpx import pytest -from backend.api.features.chat.tools.agent_generator import service +from backend.copilot.tools.agent_generator import service class TestServiceConfiguration: diff --git a/autogpt_platform/backend/test/chat/test_security_hooks.py b/autogpt_platform/backend/test/chat/test_security_hooks.py deleted file mode 100644 index f10a90871b..0000000000 --- a/autogpt_platform/backend/test/chat/test_security_hooks.py +++ /dev/null @@ -1,133 +0,0 @@ -"""Tests for SDK security hooks — workspace paths, tool access, and deny messages. - -These are pure unit tests with no external dependencies (no SDK, no DB, no server). -They validate that the security hooks correctly block unauthorized paths, -tool access, and dangerous input patterns. - -Note: Bash command validation was removed — the SDK built-in Bash tool is not in -allowed_tools, and the bash_exec MCP tool has kernel-level network isolation -(unshare --net) making command-level parsing unnecessary. -""" - -from backend.api.features.chat.sdk.security_hooks import ( - _validate_tool_access, - _validate_workspace_path, -) - -SDK_CWD = "/tmp/copilot-test-session" - - -def _is_denied(result: dict) -> bool: - hook = result.get("hookSpecificOutput", {}) - return hook.get("permissionDecision") == "deny" - - -def _reason(result: dict) -> str: - return result.get("hookSpecificOutput", {}).get("permissionDecisionReason", "") - - -# ============================================================ -# Workspace path validation (Read, Write, Edit, etc.) -# ============================================================ - - -class TestWorkspacePathValidation: - def test_path_in_workspace(self): - result = _validate_workspace_path( - "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD - ) - assert not _is_denied(result) - - def test_path_outside_workspace(self): - result = _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD) - assert _is_denied(result) - - def test_tool_results_allowed(self): - result = _validate_workspace_path( - "Read", - {"file_path": "~/.claude/projects/abc/tool-results/out.txt"}, - SDK_CWD, - ) - assert not _is_denied(result) - - def test_claude_settings_blocked(self): - result = _validate_workspace_path( - "Read", {"file_path": "~/.claude/settings.json"}, SDK_CWD - ) - assert _is_denied(result) - - def test_claude_projects_without_tool_results(self): - result = _validate_workspace_path( - "Read", {"file_path": "~/.claude/projects/abc/credentials.json"}, SDK_CWD - ) - assert _is_denied(result) - - def test_no_path_allowed(self): - """Glob/Grep without path defaults to cwd — should be allowed.""" - result = _validate_workspace_path("Grep", {"pattern": "foo"}, SDK_CWD) - assert not _is_denied(result) - - def test_path_traversal_with_dotdot(self): - result = _validate_workspace_path( - "Read", {"file_path": f"{SDK_CWD}/../../../etc/passwd"}, SDK_CWD - ) - assert _is_denied(result) - - -# ============================================================ -# Tool access validation -# ============================================================ - - -class TestToolAccessValidation: - def test_blocked_tools(self): - for tool in ("bash", "shell", "exec", "terminal", "command"): - result = _validate_tool_access(tool, {}) - assert _is_denied(result), f"Tool '{tool}' should be blocked" - - def test_bash_builtin_blocked(self): - """SDK built-in Bash (capital) is blocked as defence-in-depth.""" - result = _validate_tool_access("Bash", {"command": "echo hello"}, SDK_CWD) - assert _is_denied(result) - assert "Bash" in _reason(result) - - def test_workspace_tools_delegate(self): - result = _validate_tool_access( - "Read", {"file_path": f"{SDK_CWD}/file.txt"}, SDK_CWD - ) - assert not _is_denied(result) - - def test_dangerous_pattern_blocked(self): - result = _validate_tool_access("SomeUnknownTool", {"data": "sudo rm -rf /"}) - assert _is_denied(result) - - def test_safe_unknown_tool_allowed(self): - result = _validate_tool_access("SomeSafeTool", {"data": "hello world"}) - assert not _is_denied(result) - - -# ============================================================ -# Deny message quality (ntindle feedback) -# ============================================================ - - -class TestDenyMessageClarity: - """Deny messages must include [SECURITY] and 'cannot be bypassed' - so the model knows the restriction is enforced, not a suggestion.""" - - def test_blocked_tool_message(self): - reason = _reason(_validate_tool_access("bash", {})) - assert "[SECURITY]" in reason - assert "cannot be bypassed" in reason - - def test_bash_builtin_blocked_message(self): - reason = _reason(_validate_tool_access("Bash", {"command": "echo hello"})) - assert "[SECURITY]" in reason - assert "cannot be bypassed" in reason - - def test_workspace_path_message(self): - reason = _reason( - _validate_workspace_path("Read", {"file_path": "/etc/passwd"}, SDK_CWD) - ) - assert "[SECURITY]" in reason - assert "cannot be bypassed" in reason diff --git a/autogpt_platform/docker-compose.platform.yml b/autogpt_platform/docker-compose.platform.yml index a104afa63b..906cbb40a7 100644 --- a/autogpt_platform/docker-compose.platform.yml +++ b/autogpt_platform/docker-compose.platform.yml @@ -157,6 +157,41 @@ services: max-size: "10m" max-file: "3" + copilot_executor: + build: + context: ../ + dockerfile: autogpt_platform/backend/Dockerfile + target: server + command: ["python", "-m", "backend.copilot.executor"] + develop: + watch: + - path: ./ + target: autogpt_platform/backend/ + action: rebuild + depends_on: + redis: + condition: service_healthy + rabbitmq: + condition: service_healthy + db: + condition: service_healthy + migrate: + condition: service_completed_successfully + database_manager: + condition: service_started + <<: *backend-env-files + environment: + <<: *backend-env + ports: + - "8008:8008" + networks: + - app-network + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + websocket_server: build: context: ../ diff --git a/autogpt_platform/docker-compose.yml b/autogpt_platform/docker-compose.yml index 1860252f46..be1ee32f20 100644 --- a/autogpt_platform/docker-compose.yml +++ b/autogpt_platform/docker-compose.yml @@ -53,6 +53,12 @@ services: file: ./docker-compose.platform.yml service: executor + copilot_executor: + <<: *agpt-services + extends: + file: ./docker-compose.platform.yml + service: copilot_executor + websocket_server: <<: *agpt-services extends: @@ -174,5 +180,6 @@ services: - deps - rest_server - executor + - copilot_executor - websocket_server - database_manager