From ee2805d14c1a96dd8f5d0c6c0e45f18a041c2b2e Mon Sep 17 00:00:00 2001 From: Reinier van der Leer Date: Wed, 11 Feb 2026 13:43:58 +0100 Subject: [PATCH] fix(backend/copilot): Use `DatabaseManager` where needed --- autogpt_platform/backend/backend/app.py | 3 +- .../backend/backend/copilot/db.py | 4 +- .../backend/copilot/executor/processor.py | 20 +- .../backend/backend/copilot/model.py | 293 +++++++++--------- .../backend/backend/copilot/service.py | 4 +- .../copilot/tools/add_understanding.py | 10 +- .../copilot/tools/agent_generator/core.py | 29 +- .../backend/copilot/tools/agent_output.py | 25 +- .../backend/copilot/tools/agent_search.py | 13 +- .../backend/copilot/tools/customize_agent.py | 4 +- .../backend/copilot/tools/find_block.py | 4 +- .../backend/copilot/tools/run_agent.py | 11 +- .../backend/copilot/tools/run_block.py | 4 +- .../backend/copilot/tools/search_docs.py | 4 +- .../backend/backend/copilot/tools/utils.py | 12 +- .../backend/copilot/tools/workspace_files.py | 10 +- .../backend/backend/data/db_accessors.py | 118 +++++++ .../database.py => data/db_manager.py} | 159 ++++++++-- autogpt_platform/backend/backend/db.py | 2 +- .../backend/backend/executor/__init__.py | 4 - .../executor/activity_status_generator.py | 2 +- .../backend/executor/automod/manager.py | 2 +- .../backend/backend/executor/manager.py | 5 +- .../backend/backend/util/clients.py | 9 +- autogpt_platform/backend/backend/util/test.py | 3 +- 25 files changed, 486 insertions(+), 268 deletions(-) create mode 100644 autogpt_platform/backend/backend/data/db_accessors.py rename autogpt_platform/backend/backend/{executor/database.py => data/db_manager.py} (72%) diff --git a/autogpt_platform/backend/backend/app.py b/autogpt_platform/backend/backend/app.py index d3abd80b12..90a218d2e5 100644 --- a/autogpt_platform/backend/backend/app.py +++ b/autogpt_platform/backend/backend/app.py @@ -39,7 +39,8 @@ def main(**kwargs): from backend.api.rest_api import AgentServer from backend.api.ws_api import WebsocketServer from backend.copilot.executor.manager import CoPilotExecutor - from backend.executor import DatabaseManager, ExecutionManager, Scheduler + from backend.data.db_manager import DatabaseManager + from backend.executor import ExecutionManager, Scheduler from backend.notifications import NotificationManager run_processes( diff --git a/autogpt_platform/backend/backend/copilot/db.py b/autogpt_platform/backend/backend/copilot/db.py index 303ea0a698..f94d959f05 100644 --- a/autogpt_platform/backend/backend/copilot/db.py +++ b/autogpt_platform/backend/backend/copilot/db.py @@ -14,7 +14,7 @@ from prisma.types import ( ChatSessionWhereInput, ) -from backend.data.db import transaction +from backend.data import db from backend.util.json import SafeJson logger = logging.getLogger(__name__) @@ -147,7 +147,7 @@ async def add_chat_messages_batch( created_messages = [] - async with transaction() as tx: + async with db.transaction() as tx: for i, msg in enumerate(messages): # Build input dict dynamically rather than using ChatMessageCreateInput # directly because Prisma's TypedDict validation rejects optional fields diff --git a/autogpt_platform/backend/backend/copilot/executor/processor.py b/autogpt_platform/backend/backend/copilot/executor/processor.py index 85f555672a..f8040df403 100644 --- a/autogpt_platform/backend/backend/copilot/executor/processor.py +++ b/autogpt_platform/backend/backend/copilot/executor/processor.py @@ -79,7 +79,10 @@ class CoPilotProcessor: """Initialize the processor when the worker thread starts. This method is called once per worker thread to set up the async event - loop, connect to Prisma, and initialize any required resources. + loop and initialize any required resources. + + Database is accessed only through DatabaseManager, so we don't need to connect + to Prisma directly. """ configure_logging() set_service_name("CoPilotExecutor") @@ -90,23 +93,8 @@ class CoPilotProcessor: ) self.execution_thread.start() - # Connect to Prisma in the worker's event loop - # This is required because the CoPilot service uses Prisma directly - # TODO: Use DatabaseManager, avoid direct Prisma connection(?) - asyncio.run_coroutine_threadsafe( - self._connect_prisma(), self.execution_loop - ).result(timeout=30.0) - logger.info(f"[CoPilotExecutor] Worker {self.tid} started") - async def _connect_prisma(self): - """Connect to Prisma database in the worker's event loop.""" - from backend.data import db - - if not db.is_connected(): - await db.connect() - logger.info(f"[CoPilotExecutor] Worker {self.tid} connected to Prisma") - @error_logged(swallow=False) def execute( self, diff --git a/autogpt_platform/backend/backend/copilot/model.py b/autogpt_platform/backend/backend/copilot/model.py index 7318ef88d7..baeef7a145 100644 --- a/autogpt_platform/backend/backend/copilot/model.py +++ b/autogpt_platform/backend/backend/copilot/model.py @@ -23,26 +23,17 @@ from prisma.models import ChatMessage as PrismaChatMessage from prisma.models import ChatSession as PrismaChatSession from pydantic import BaseModel +from backend.data.db_accessors import chat_db from backend.data.redis_client import get_redis_async from backend.util import json from backend.util.exceptions import DatabaseError, RedisError -from . import db as chat_db from .config import ChatConfig logger = logging.getLogger(__name__) config = ChatConfig() -def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any: - """Parse a JSON field that may be stored as string or already parsed.""" - if value is None: - return default - if isinstance(value, str): - return json.loads(value) - return value - - # Redis cache key prefix for chat sessions CHAT_SESSION_CACHE_PREFIX = "chat:session:" @@ -52,28 +43,7 @@ def _get_session_cache_key(session_id: str) -> str: return f"{CHAT_SESSION_CACHE_PREFIX}{session_id}" -# Session-level locks to prevent race conditions during concurrent upserts. -# Uses WeakValueDictionary to automatically garbage collect locks when no longer referenced, -# preventing unbounded memory growth while maintaining lock semantics for active sessions. -# Invalidation: Locks are auto-removed by GC when no coroutine holds a reference (after -# async with lock: completes). Explicit cleanup also occurs in delete_chat_session(). -_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary() -_session_locks_mutex = asyncio.Lock() - - -async def _get_session_lock(session_id: str) -> asyncio.Lock: - """Get or create a lock for a specific session to prevent concurrent upserts. - - Uses WeakValueDictionary for automatic cleanup: locks are garbage collected - when no coroutine holds a reference to them, preventing memory leaks from - unbounded growth of session locks. - """ - async with _session_locks_mutex: - lock = _session_locks.get(session_id) - if lock is None: - lock = asyncio.Lock() - _session_locks[session_id] = lock - return lock +# ===================== Chat data models ===================== # class ChatMessage(BaseModel): @@ -261,38 +231,26 @@ class ChatSession(BaseModel): return messages -async def _get_session_from_cache(session_id: str) -> ChatSession | None: - """Get a chat session from Redis cache.""" - redis_key = _get_session_cache_key(session_id) - async_redis = await get_redis_async() - raw_session: bytes | None = await async_redis.get(redis_key) - - if raw_session is None: - return None - - try: - session = ChatSession.model_validate_json(raw_session) - logger.info( - f"Loading session {session_id} from cache: " - f"message_count={len(session.messages)}, " - f"roles={[m.role for m in session.messages]}" - ) - return session - except Exception as e: - logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True) - raise RedisError(f"Corrupted session data for {session_id}") from e +def _parse_json_field(value: str | dict | list | None, default: Any = None) -> Any: + """Parse a JSON field that may be stored as string or already parsed.""" + if value is None: + return default + if isinstance(value, str): + return json.loads(value) + return value -async def _cache_session(session: ChatSession) -> None: - """Cache a chat session in Redis.""" - redis_key = _get_session_cache_key(session.session_id) - async_redis = await get_redis_async() - await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json()) +# ================ Chat cache + DB operations ================ # + +# NOTE: Database calls are automatically routed through DatabaseManager if Prisma is not +# connected directly. async def cache_chat_session(session: ChatSession) -> None: - """Cache a chat session without persisting to the database.""" - await _cache_session(session) + """Cache a chat session in Redis (without persisting to the database).""" + redis_key = _get_session_cache_key(session.session_id) + async_redis = await get_redis_async() + await async_redis.setex(redis_key, config.session_ttl, session.model_dump_json()) async def invalidate_session_cache(session_id: str) -> None: @@ -310,80 +268,6 @@ async def invalidate_session_cache(session_id: str) -> None: logger.warning(f"Failed to invalidate session cache for {session_id}: {e}") -async def _get_session_from_db(session_id: str) -> ChatSession | None: - """Get a chat session from the database.""" - prisma_session = await chat_db.get_chat_session(session_id) - if not prisma_session: - return None - - messages = prisma_session.Messages - logger.info( - f"Loading session {session_id} from DB: " - f"has_messages={messages is not None}, " - f"message_count={len(messages) if messages else 0}, " - f"roles={[m.role for m in messages] if messages else []}" - ) - - return ChatSession.from_db(prisma_session, messages) - - -async def _save_session_to_db( - session: ChatSession, existing_message_count: int -) -> None: - """Save or update a chat session in the database.""" - # Check if session exists in DB - existing = await chat_db.get_chat_session(session.session_id) - - if not existing: - # Create new session - await chat_db.create_chat_session( - session_id=session.session_id, - user_id=session.user_id, - ) - existing_message_count = 0 - - # Calculate total tokens from usage - total_prompt = sum(u.prompt_tokens for u in session.usage) - total_completion = sum(u.completion_tokens for u in session.usage) - - # Update session metadata - await chat_db.update_chat_session( - session_id=session.session_id, - credentials=session.credentials, - successful_agent_runs=session.successful_agent_runs, - successful_agent_schedules=session.successful_agent_schedules, - total_prompt_tokens=total_prompt, - total_completion_tokens=total_completion, - ) - - # Add new messages (only those after existing count) - new_messages = session.messages[existing_message_count:] - if new_messages: - messages_data = [] - for msg in new_messages: - messages_data.append( - { - "role": msg.role, - "content": msg.content, - "name": msg.name, - "tool_call_id": msg.tool_call_id, - "refusal": msg.refusal, - "tool_calls": msg.tool_calls, - "function_call": msg.function_call, - } - ) - logger.info( - f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: " - f"roles={[m['role'] for m in messages_data]}, " - f"start_sequence={existing_message_count}" - ) - await chat_db.add_chat_messages_batch( - session_id=session.session_id, - messages=messages_data, - start_sequence=existing_message_count, - ) - - async def get_chat_session( session_id: str, user_id: str | None = None, @@ -431,7 +315,7 @@ async def get_chat_session( # Cache the session from DB try: - await _cache_session(session) + await cache_chat_session(session) logger.info(f"Cached session {session_id} from database") except Exception as e: logger.warning(f"Failed to cache session {session_id}: {e}") @@ -439,6 +323,45 @@ async def get_chat_session( return session +async def _get_session_from_cache(session_id: str) -> ChatSession | None: + """Get a chat session from Redis cache.""" + redis_key = _get_session_cache_key(session_id) + async_redis = await get_redis_async() + raw_session: bytes | None = await async_redis.get(redis_key) + + if raw_session is None: + return None + + try: + session = ChatSession.model_validate_json(raw_session) + logger.info( + f"Loading session {session_id} from cache: " + f"message_count={len(session.messages)}, " + f"roles={[m.role for m in session.messages]}" + ) + return session + except Exception as e: + logger.error(f"Failed to deserialize session {session_id}: {e}", exc_info=True) + raise RedisError(f"Corrupted session data for {session_id}") from e + + +async def _get_session_from_db(session_id: str) -> ChatSession | None: + """Get a chat session from the database.""" + prisma_session = await chat_db().get_chat_session(session_id) + if not prisma_session: + return None + + messages = prisma_session.Messages + logger.info( + f"Loading session {session_id} from DB: " + f"has_messages={messages is not None}, " + f"message_count={len(messages) if messages else 0}, " + f"roles={[m.role for m in messages] if messages else []}" + ) + + return ChatSession.from_db(prisma_session, messages) + + async def upsert_chat_session( session: ChatSession, ) -> ChatSession: @@ -459,7 +382,7 @@ async def upsert_chat_session( async with lock: # Get existing message count from DB for incremental saves - existing_message_count = await chat_db.get_chat_session_message_count( + existing_message_count = await chat_db().get_chat_session_message_count( session.session_id ) @@ -476,7 +399,7 @@ async def upsert_chat_session( # Save to cache (best-effort, even if DB failed) try: - await _cache_session(session) + await cache_chat_session(session) except Exception as e: # If DB succeeded but cache failed, raise cache error if db_error is None: @@ -497,6 +420,65 @@ async def upsert_chat_session( return session +async def _save_session_to_db( + session: ChatSession, existing_message_count: int +) -> None: + """Save or update a chat session in the database.""" + db = chat_db() + + # Check if session exists in DB + existing = await db.get_chat_session(session.session_id) + + if not existing: + # Create new session + await db.create_chat_session( + session_id=session.session_id, + user_id=session.user_id, + ) + existing_message_count = 0 + + # Calculate total tokens from usage + total_prompt = sum(u.prompt_tokens for u in session.usage) + total_completion = sum(u.completion_tokens for u in session.usage) + + # Update session metadata + await db.update_chat_session( + session_id=session.session_id, + credentials=session.credentials, + successful_agent_runs=session.successful_agent_runs, + successful_agent_schedules=session.successful_agent_schedules, + total_prompt_tokens=total_prompt, + total_completion_tokens=total_completion, + ) + + # Add new messages (only those after existing count) + new_messages = session.messages[existing_message_count:] + if new_messages: + messages_data = [] + for msg in new_messages: + messages_data.append( + { + "role": msg.role, + "content": msg.content, + "name": msg.name, + "tool_call_id": msg.tool_call_id, + "refusal": msg.refusal, + "tool_calls": msg.tool_calls, + "function_call": msg.function_call, + } + ) + logger.info( + f"Saving {len(new_messages)} new messages to DB for session {session.session_id}: " + f"roles={[m['role'] for m in messages_data]}, " + f"start_sequence={existing_message_count}" + ) + await db.add_chat_messages_batch( + session_id=session.session_id, + messages=messages_data, + start_sequence=existing_message_count, + ) + + async def create_chat_session(user_id: str) -> ChatSession: """Create a new chat session and persist it. @@ -509,7 +491,7 @@ async def create_chat_session(user_id: str) -> ChatSession: # Create in database first - fail fast if this fails try: - await chat_db.create_chat_session( + await chat_db().create_chat_session( session_id=session.session_id, user_id=user_id, ) @@ -521,7 +503,7 @@ async def create_chat_session(user_id: str) -> ChatSession: # Cache the session (best-effort optimization, DB is source of truth) try: - await _cache_session(session) + await cache_chat_session(session) except Exception as e: logger.warning(f"Failed to cache new session {session.session_id}: {e}") @@ -539,8 +521,9 @@ async def get_user_sessions( A tuple of (sessions, total_count) where total_count is the overall number of sessions for the user (not just the current page). """ - prisma_sessions = await chat_db.get_user_chat_sessions(user_id, limit, offset) - total_count = await chat_db.get_user_session_count(user_id) + db = chat_db() + prisma_sessions = await db.get_user_chat_sessions(user_id, limit, offset) + total_count = await db.get_user_session_count(user_id) sessions = [] for prisma_session in prisma_sessions: @@ -563,7 +546,7 @@ async def delete_chat_session(session_id: str, user_id: str | None = None) -> bo """ # Delete from database first (with optional user_id validation) # This confirms ownership before invalidating cache - deleted = await chat_db.delete_chat_session(session_id, user_id) + deleted = await chat_db().delete_chat_session(session_id, user_id) if not deleted: return False @@ -598,7 +581,7 @@ async def update_session_title(session_id: str, title: str) -> bool: True if updated successfully, False otherwise. """ try: - result = await chat_db.update_chat_session(session_id=session_id, title=title) + result = await chat_db().update_chat_session(session_id=session_id, title=title) if result is None: logger.warning(f"Session {session_id} not found for title update") return False @@ -615,3 +598,29 @@ async def update_session_title(session_id: str, title: str) -> bool: except Exception as e: logger.error(f"Failed to update title for session {session_id}: {e}") return False + + +# ==================== Chat session locks ==================== # + +_session_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary() +_session_locks_mutex = asyncio.Lock() + + +async def _get_session_lock(session_id: str) -> asyncio.Lock: + """Get or create a lock for a specific session to prevent concurrent upserts. + + This was originally added to solve the specific problem of race conditions between + the session title thread and the conversation thread, which always occurs on the + same instance as we prevent rapid request sends on the frontend. + + Uses WeakValueDictionary for automatic cleanup: locks are garbage collected + when no coroutine holds a reference to them, preventing memory leaks from + unbounded growth of session locks. Explicit cleanup also occurs + in `delete_chat_session()`. + """ + async with _session_locks_mutex: + lock = _session_locks.get(session_id) + if lock is None: + lock = asyncio.Lock() + _session_locks[session_id] = lock + return lock diff --git a/autogpt_platform/backend/backend/copilot/service.py b/autogpt_platform/backend/backend/copilot/service.py index 072ea88fd5..7edc580481 100644 --- a/autogpt_platform/backend/backend/copilot/service.py +++ b/autogpt_platform/backend/backend/copilot/service.py @@ -27,6 +27,7 @@ from openai.types.chat import ( ChatCompletionToolParam, ) +from backend.data.db_accessors import chat_db from backend.data.redis_client import get_redis_async from backend.data.understanding import ( format_understanding_for_prompt, @@ -35,7 +36,6 @@ from backend.data.understanding import ( from backend.util.exceptions import NotFoundError from backend.util.settings import AppEnvironment, Settings -from . import db as chat_db from . import stream_registry from .config import ChatConfig from .model import ( @@ -1744,7 +1744,7 @@ async def _update_pending_operation( This is called by background tasks when long-running operations complete. """ # Update the message in database - updated = await chat_db.update_tool_message_content( + updated = await chat_db().update_tool_message_content( session_id=session_id, tool_call_id=tool_call_id, new_content=result, diff --git a/autogpt_platform/backend/backend/copilot/tools/add_understanding.py b/autogpt_platform/backend/backend/copilot/tools/add_understanding.py index 6da759d3cf..b3291c5b0e 100644 --- a/autogpt_platform/backend/backend/copilot/tools/add_understanding.py +++ b/autogpt_platform/backend/backend/copilot/tools/add_understanding.py @@ -4,10 +4,8 @@ import logging from typing import Any from backend.copilot.model import ChatSession -from backend.data.understanding import ( - BusinessUnderstandingInput, - upsert_business_understanding, -) +from backend.data.db_accessors import understanding_db +from backend.data.understanding import BusinessUnderstandingInput from .base import BaseTool from .models import ErrorResponse, ToolResponseBase, UnderstandingUpdatedResponse @@ -99,7 +97,9 @@ and automations for the user's specific needs.""" ] # Upsert with merge - understanding = await upsert_business_understanding(user_id, input_data) + understanding = await understanding_db().upsert_business_understanding( + user_id, input_data + ) # Build current understanding summary (filter out empty values) current_understanding = { diff --git a/autogpt_platform/backend/backend/copilot/tools/agent_generator/core.py b/autogpt_platform/backend/backend/copilot/tools/agent_generator/core.py index f83ca30b5c..365bcbf244 100644 --- a/autogpt_platform/backend/backend/copilot/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/copilot/tools/agent_generator/core.py @@ -5,9 +5,8 @@ import re import uuid from typing import Any, NotRequired, TypedDict -from backend.api.features.library import db as library_db -from backend.api.features.store import db as store_db -from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs +from backend.data.db_accessors import graph_db, library_db, store_db +from backend.data.graph import Graph, Link, Node from backend.util.exceptions import DatabaseError, NotFoundError from .service import ( @@ -145,8 +144,9 @@ async def get_library_agent_by_id( Returns: LibraryAgentSummary if found, None otherwise """ + db = library_db() try: - agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id) + agent = await db.get_library_agent_by_graph_id(user_id, agent_id) if agent: logger.debug(f"Found library agent by graph_id: {agent.name}") return LibraryAgentSummary( @@ -163,7 +163,7 @@ async def get_library_agent_by_id( logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}") try: - agent = await library_db.get_library_agent(agent_id, user_id) + agent = await db.get_library_agent(agent_id, user_id) if agent: logger.debug(f"Found library agent by library_id: {agent.name}") return LibraryAgentSummary( @@ -215,7 +215,7 @@ async def get_library_agents_for_generation( List of LibraryAgentSummary with schemas and recent executions for sub-agent composition """ try: - response = await library_db.list_library_agents( + response = await library_db().list_library_agents( user_id=user_id, search_term=search_query, page=1, @@ -272,7 +272,7 @@ async def search_marketplace_agents_for_generation( List of LibraryAgentSummary with full input/output schemas """ try: - response = await store_db.get_store_agents( + response = await store_db().get_store_agents( search_query=search_query, page=1, page_size=max_results, @@ -286,7 +286,7 @@ async def search_marketplace_agents_for_generation( return [] graph_ids = [agent.agent_graph_id for agent in agents_with_graphs] - graphs = await get_store_listed_graphs(*graph_ids) + graphs = await graph_db().get_store_listed_graphs(*graph_ids) results: list[LibraryAgentSummary] = [] for agent in agents_with_graphs: @@ -673,9 +673,10 @@ async def save_agent_to_library( Tuple of (created Graph, LibraryAgent) """ graph = json_to_graph(agent_json) + db = library_db() if is_update: - return await library_db.update_graph_in_library(graph, user_id) - return await library_db.create_graph_in_library(graph, user_id) + return await db.update_graph_in_library(graph, user_id) + return await db.create_graph_in_library(graph, user_id) def graph_to_json(graph: Graph) -> dict[str, Any]: @@ -735,12 +736,14 @@ async def get_agent_as_json( Returns: Agent as JSON dict or None if not found """ - graph = await get_graph(agent_id, version=None, user_id=user_id) + db = graph_db() + + graph = await db.get_graph(agent_id, version=None, user_id=user_id) if not graph and user_id: try: - library_agent = await library_db.get_library_agent(agent_id, user_id) - graph = await get_graph( + library_agent = await library_db().get_library_agent(agent_id, user_id) + graph = await db.get_graph( library_agent.graph_id, version=None, user_id=user_id ) except NotFoundError: diff --git a/autogpt_platform/backend/backend/copilot/tools/agent_output.py b/autogpt_platform/backend/backend/copilot/tools/agent_output.py index 96491b749a..fe4767d09e 100644 --- a/autogpt_platform/backend/backend/copilot/tools/agent_output.py +++ b/autogpt_platform/backend/backend/copilot/tools/agent_output.py @@ -7,10 +7,9 @@ from typing import Any from pydantic import BaseModel, field_validator -from backend.api.features.library import db as library_db from backend.api.features.library.model import LibraryAgent from backend.copilot.model import ChatSession -from backend.data import execution as execution_db +from backend.data.db_accessors import execution_db, library_db from backend.data.execution import ExecutionStatus, GraphExecution, GraphExecutionMeta from .base import BaseTool @@ -165,10 +164,12 @@ class AgentOutputTool(BaseTool): Resolve agent from provided identifiers. Returns (library_agent, error_message). """ + lib_db = library_db() + # Priority 1: Exact library agent ID if library_agent_id: try: - agent = await library_db.get_library_agent(library_agent_id, user_id) + agent = await lib_db.get_library_agent(library_agent_id, user_id) return agent, None except Exception as e: logger.warning(f"Failed to get library agent by ID: {e}") @@ -182,7 +183,7 @@ class AgentOutputTool(BaseTool): return None, f"Agent '{store_slug}' not found in marketplace" # Find in user's library by graph_id - agent = await library_db.get_library_agent_by_graph_id(user_id, graph.id) + agent = await lib_db.get_library_agent_by_graph_id(user_id, graph.id) if not agent: return ( None, @@ -194,7 +195,7 @@ class AgentOutputTool(BaseTool): # Priority 3: Fuzzy name search in library if agent_name: try: - response = await library_db.list_library_agents( + response = await lib_db.list_library_agents( user_id=user_id, search_term=agent_name, page_size=5, @@ -228,9 +229,11 @@ class AgentOutputTool(BaseTool): Fetch execution(s) based on filters. Returns (single_execution, available_executions_meta, error_message). """ + exec_db = execution_db() + # If specific execution_id provided, fetch it directly if execution_id: - execution = await execution_db.get_graph_execution( + execution = await exec_db.get_graph_execution( user_id=user_id, execution_id=execution_id, include_node_executions=False, @@ -240,7 +243,7 @@ class AgentOutputTool(BaseTool): return execution, [], None # Get completed executions with time filters - executions = await execution_db.get_graph_executions( + executions = await exec_db.get_graph_executions( graph_id=graph_id, user_id=user_id, statuses=[ExecutionStatus.COMPLETED], @@ -254,7 +257,7 @@ class AgentOutputTool(BaseTool): # If only one execution, fetch full details if len(executions) == 1: - full_execution = await execution_db.get_graph_execution( + full_execution = await exec_db.get_graph_execution( user_id=user_id, execution_id=executions[0].id, include_node_executions=False, @@ -262,7 +265,7 @@ class AgentOutputTool(BaseTool): return full_execution, [], None # Multiple executions - return latest with full details, plus list of available - full_execution = await execution_db.get_graph_execution( + full_execution = await exec_db.get_graph_execution( user_id=user_id, execution_id=executions[0].id, include_node_executions=False, @@ -380,7 +383,7 @@ class AgentOutputTool(BaseTool): and not input_data.store_slug ): # Fetch execution directly to get graph_id - execution = await execution_db.get_graph_execution( + execution = await execution_db().get_graph_execution( user_id=user_id, execution_id=input_data.execution_id, include_node_executions=False, @@ -392,7 +395,7 @@ class AgentOutputTool(BaseTool): ) # Find library agent by graph_id - agent = await library_db.get_library_agent_by_graph_id( + agent = await library_db().get_library_agent_by_graph_id( user_id, execution.graph_id ) if not agent: diff --git a/autogpt_platform/backend/backend/copilot/tools/agent_search.py b/autogpt_platform/backend/backend/copilot/tools/agent_search.py index 61cdba1ef9..3c380a7150 100644 --- a/autogpt_platform/backend/backend/copilot/tools/agent_search.py +++ b/autogpt_platform/backend/backend/copilot/tools/agent_search.py @@ -4,8 +4,7 @@ import logging import re from typing import Literal -from backend.api.features.library import db as library_db -from backend.api.features.store import db as store_db +from backend.data.db_accessors import library_db, store_db from backend.util.exceptions import DatabaseError, NotFoundError from .models import ( @@ -45,8 +44,10 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N Returns: AgentInfo if found, None otherwise """ + lib_db = library_db() + try: - agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id) + agent = await lib_db.get_library_agent_by_graph_id(user_id, agent_id) if agent: logger.debug(f"Found library agent by graph_id: {agent.name}") return AgentInfo( @@ -71,7 +72,7 @@ async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | N ) try: - agent = await library_db.get_library_agent(agent_id, user_id) + agent = await lib_db.get_library_agent(agent_id, user_id) if agent: logger.debug(f"Found library agent by library_id: {agent.name}") return AgentInfo( @@ -133,7 +134,7 @@ async def search_agents( try: if source == "marketplace": logger.info(f"Searching marketplace for: {query}") - results = await store_db.get_store_agents(search_query=query, page_size=5) + results = await store_db().get_store_agents(search_query=query, page_size=5) for agent in results.agents: agents.append( AgentInfo( @@ -159,7 +160,7 @@ async def search_agents( if not agents: logger.info(f"Searching user library for: {query}") - results = await library_db.list_library_agents( + results = await library_db().list_library_agents( user_id=user_id, # type: ignore[arg-type] search_term=query, page_size=10, diff --git a/autogpt_platform/backend/backend/copilot/tools/customize_agent.py b/autogpt_platform/backend/backend/copilot/tools/customize_agent.py index a85a69196d..96e19656c6 100644 --- a/autogpt_platform/backend/backend/copilot/tools/customize_agent.py +++ b/autogpt_platform/backend/backend/copilot/tools/customize_agent.py @@ -3,9 +3,9 @@ import logging from typing import Any -from backend.api.features.store import db as store_db from backend.api.features.store.exceptions import AgentNotFoundError from backend.copilot.model import ChatSession +from backend.data.db_accessors import store_db as get_store_db from .agent_generator import ( AgentGeneratorNotConfiguredError, @@ -137,6 +137,8 @@ class CustomizeAgentTool(BaseTool): creator_username, agent_slug = parts + store_db = get_store_db() + # Fetch the marketplace agent details try: agent_details = await store_db.get_store_agent_details( diff --git a/autogpt_platform/backend/backend/copilot/tools/find_block.py b/autogpt_platform/backend/backend/copilot/tools/find_block.py index c26f76fc52..ef9314eb5a 100644 --- a/autogpt_platform/backend/backend/copilot/tools/find_block.py +++ b/autogpt_platform/backend/backend/copilot/tools/find_block.py @@ -3,7 +3,6 @@ from typing import Any from prisma.enums import ContentType -from backend.api.features.store.hybrid_search import unified_hybrid_search from backend.copilot.model import ChatSession from backend.copilot.tools.base import BaseTool, ToolResponseBase from backend.copilot.tools.models import ( @@ -14,6 +13,7 @@ from backend.copilot.tools.models import ( NoResultsResponse, ) from backend.data.block import BlockType, get_block +from backend.data.db_accessors import search logger = logging.getLogger(__name__) @@ -106,7 +106,7 @@ class FindBlockTool(BaseTool): try: # Search for blocks using hybrid search - results, total = await unified_hybrid_search( + results, total = await search().unified_hybrid_search( query=query, content_types=[ContentType.BLOCK], page=1, diff --git a/autogpt_platform/backend/backend/copilot/tools/run_agent.py b/autogpt_platform/backend/backend/copilot/tools/run_agent.py index 8f19ab86cc..46af6fbcb0 100644 --- a/autogpt_platform/backend/backend/copilot/tools/run_agent.py +++ b/autogpt_platform/backend/backend/copilot/tools/run_agent.py @@ -5,13 +5,12 @@ from typing import Any from pydantic import BaseModel, Field, field_validator -from backend.api.features.library import db as library_db from backend.copilot.config import ChatConfig from backend.copilot.model import ChatSession from backend.copilot.tracking import track_agent_run_success, track_agent_scheduled +from backend.data.db_accessors import graph_db, library_db, user_db from backend.data.graph import GraphModel from backend.data.model import CredentialsMetaInput -from backend.data.user import get_user_by_id from backend.executor import utils as execution_utils from backend.util.clients import get_scheduler_client from backend.util.exceptions import DatabaseError, NotFoundError @@ -197,7 +196,7 @@ class RunAgentTool(BaseTool): # Priority: library_agent_id if provided if has_library_id: - library_agent = await library_db.get_library_agent( + library_agent = await library_db().get_library_agent( params.library_agent_id, user_id ) if not library_agent: @@ -206,9 +205,7 @@ class RunAgentTool(BaseTool): session_id=session_id, ) # Get the graph from the library agent - from backend.data.graph import get_graph - - graph = await get_graph( + graph = await graph_db().get_graph( library_agent.graph_id, library_agent.graph_version, user_id=user_id, @@ -519,7 +516,7 @@ class RunAgentTool(BaseTool): library_agent = await get_or_create_library_agent(graph, user_id) # Get user timezone - user = await get_user_by_id(user_id) + user = await user_db().get_user_by_id(user_id) user_timezone = get_user_timezone_or_utc(user.timezone if user else timezone) # Create schedule diff --git a/autogpt_platform/backend/backend/copilot/tools/run_block.py b/autogpt_platform/backend/backend/copilot/tools/run_block.py index dd813398de..56b664a77a 100644 --- a/autogpt_platform/backend/backend/copilot/tools/run_block.py +++ b/autogpt_platform/backend/backend/copilot/tools/run_block.py @@ -13,9 +13,9 @@ from backend.copilot.tools.find_block import ( COPILOT_EXCLUDED_BLOCK_TYPES, ) from backend.data.block import AnyBlockSchema, get_block +from backend.data.db_accessors import workspace_db from backend.data.execution import ExecutionContext from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput -from backend.data.workspace import get_or_create_workspace from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import BlockError @@ -189,7 +189,7 @@ class RunBlockTool(BaseTool): try: # Get or create user's workspace for CoPilot file operations - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Generate synthetic IDs for CoPilot context # Each chat session is treated as its own agent with one continuous run diff --git a/autogpt_platform/backend/backend/copilot/tools/search_docs.py b/autogpt_platform/backend/backend/copilot/tools/search_docs.py index 056b86d4af..95b571bdcc 100644 --- a/autogpt_platform/backend/backend/copilot/tools/search_docs.py +++ b/autogpt_platform/backend/backend/copilot/tools/search_docs.py @@ -5,7 +5,6 @@ from typing import Any from prisma.enums import ContentType -from backend.api.features.store.hybrid_search import unified_hybrid_search from backend.copilot.model import ChatSession from backend.copilot.tools.base import BaseTool from backend.copilot.tools.models import ( @@ -15,6 +14,7 @@ from backend.copilot.tools.models import ( NoResultsResponse, ToolResponseBase, ) +from backend.data.db_accessors import search logger = logging.getLogger(__name__) @@ -117,7 +117,7 @@ class SearchDocsTool(BaseTool): try: # Search using hybrid search for DOCUMENTATION content type only - results, total = await unified_hybrid_search( + results, total = await search().unified_hybrid_search( query=query, content_types=[ContentType.DOCUMENTATION], page=1, diff --git a/autogpt_platform/backend/backend/copilot/tools/utils.py b/autogpt_platform/backend/backend/copilot/tools/utils.py index 80a842bf36..b200016b02 100644 --- a/autogpt_platform/backend/backend/copilot/tools/utils.py +++ b/autogpt_platform/backend/backend/copilot/tools/utils.py @@ -3,9 +3,8 @@ import logging from typing import Any -from backend.api.features.library import db as library_db from backend.api.features.library import model as library_model -from backend.api.features.store import db as store_db +from backend.data.db_accessors import library_db, store_db from backend.data.graph import GraphModel from backend.data.model import ( Credentials, @@ -38,13 +37,14 @@ async def fetch_graph_from_store_slug( Raises: DatabaseError: If there's a database error during lookup. """ + sdb = store_db() try: - store_agent = await store_db.get_store_agent_details(username, agent_name) + store_agent = await sdb.get_store_agent_details(username, agent_name) except NotFoundError: return None, None # Get the graph from store listing version - graph = await store_db.get_available_graph( + graph = await sdb.get_available_graph( store_agent.store_listing_version_id, hide_nodes=False ) return graph, store_agent @@ -209,13 +209,13 @@ async def get_or_create_library_agent( Returns: LibraryAgent instance """ - existing = await library_db.get_library_agent_by_graph_id( + existing = await library_db().get_library_agent_by_graph_id( graph_id=graph.id, user_id=user_id ) if existing: return existing - library_agents = await library_db.create_library_agent( + library_agents = await library_db().create_library_agent( graph=graph, user_id=user_id, create_library_agents_for_sub_graphs=False, diff --git a/autogpt_platform/backend/backend/copilot/tools/workspace_files.py b/autogpt_platform/backend/backend/copilot/tools/workspace_files.py index 50df556f03..9ecbf74052 100644 --- a/autogpt_platform/backend/backend/copilot/tools/workspace_files.py +++ b/autogpt_platform/backend/backend/copilot/tools/workspace_files.py @@ -7,7 +7,7 @@ from typing import Any, Optional from pydantic import BaseModel from backend.copilot.model import ChatSession -from backend.data.workspace import get_or_create_workspace +from backend.data.db_accessors import workspace_db from backend.util.settings import Config from backend.util.virus_scanner import scan_content_safe from backend.util.workspace import WorkspaceManager @@ -146,7 +146,7 @@ class ListWorkspaceFilesTool(BaseTool): include_all_sessions: bool = kwargs.get("include_all_sessions", False) try: - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Pass session_id for session-scoped file access manager = WorkspaceManager(user_id, workspace.id, session_id) @@ -280,7 +280,7 @@ class ReadWorkspaceFileTool(BaseTool): ) try: - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Pass session_id for session-scoped file access manager = WorkspaceManager(user_id, workspace.id, session_id) @@ -478,7 +478,7 @@ class WriteWorkspaceFileTool(BaseTool): # Virus scan await scan_content_safe(content, filename=filename) - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Pass session_id for session-scoped file access manager = WorkspaceManager(user_id, workspace.id, session_id) @@ -577,7 +577,7 @@ class DeleteWorkspaceFileTool(BaseTool): ) try: - workspace = await get_or_create_workspace(user_id) + workspace = await workspace_db().get_or_create_workspace(user_id) # Pass session_id for session-scoped file access manager = WorkspaceManager(user_id, workspace.id, session_id) diff --git a/autogpt_platform/backend/backend/data/db_accessors.py b/autogpt_platform/backend/backend/data/db_accessors.py new file mode 100644 index 0000000000..9875cabec5 --- /dev/null +++ b/autogpt_platform/backend/backend/data/db_accessors.py @@ -0,0 +1,118 @@ +from backend.data import db + + +def chat_db(): + if db.is_connected(): + from backend.copilot import db as _chat_db + + chat_db = _chat_db + else: + from backend.util.clients import get_database_manager_async_client + + chat_db = get_database_manager_async_client() + + return chat_db + + +def graph_db(): + if db.is_connected(): + from backend.data import graph as _graph_db + + graph_db = _graph_db + else: + from backend.util.clients import get_database_manager_async_client + + graph_db = get_database_manager_async_client() + + return graph_db + + +def library_db(): + if db.is_connected(): + from backend.api.features.library import db as _library_db + + library_db = _library_db + else: + from backend.util.clients import get_database_manager_async_client + + library_db = get_database_manager_async_client() + + return library_db + + +def store_db(): + if db.is_connected(): + from backend.api.features.store import db as _store_db + + store_db = _store_db + else: + from backend.util.clients import get_database_manager_async_client + + store_db = get_database_manager_async_client() + + return store_db + + +def search(): + if db.is_connected(): + from backend.api.features.store import hybrid_search as _search + + search = _search + else: + from backend.util.clients import get_database_manager_async_client + + search = get_database_manager_async_client() + + return search + + +def execution_db(): + if db.is_connected(): + from backend.data import execution as _execution_db + + execution_db = _execution_db + else: + from backend.util.clients import get_database_manager_async_client + + execution_db = get_database_manager_async_client() + + return execution_db + + +def user_db(): + if db.is_connected(): + from backend.data import user as _user_db + + user_db = _user_db + else: + from backend.util.clients import get_database_manager_async_client + + user_db = get_database_manager_async_client() + + return user_db + + +def understanding_db(): + if db.is_connected(): + from backend.data import understanding as _understanding_db + + understanding_db = _understanding_db + else: + from backend.util.clients import get_database_manager_async_client + + understanding_db = get_database_manager_async_client() + + return understanding_db + + +def workspace_db(): + if db.is_connected(): + from backend.data import workspace as _workspace_db + + workspace_db = _workspace_db + else: + from backend.util.clients import get_database_manager_async_client + + workspace_db = get_database_manager_async_client() + + return workspace_db diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/data/db_manager.py similarity index 72% rename from autogpt_platform/backend/backend/executor/database.py rename to autogpt_platform/backend/backend/data/db_manager.py index d44439d51c..090c21ad7c 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/data/db_manager.py @@ -4,14 +4,26 @@ from typing import TYPE_CHECKING, Callable, Concatenate, ParamSpec, TypeVar, cas from backend.api.features.library.db import ( add_store_agent_to_library, + create_graph_in_library, + create_library_agent, + get_library_agent, + get_library_agent_by_graph_id, list_library_agents, + update_graph_in_library, +) +from backend.api.features.store.db import ( + get_agent, + get_available_graph, + get_store_agent_details, + get_store_agents, ) -from backend.api.features.store.db import get_store_agent_details, get_store_agents from backend.api.features.store.embeddings import ( backfill_missing_embeddings, cleanup_orphaned_embeddings, get_embedding_stats, ) +from backend.api.features.store.hybrid_search import unified_hybrid_search +from backend.copilot import db as chat_db from backend.data import db from backend.data.analytics import ( get_accuracy_trends_and_alerts, @@ -48,6 +60,7 @@ from backend.data.graph import ( get_graph_metadata, get_graph_settings, get_node, + get_store_listed_graphs, validate_graph_execution_permissions, ) from backend.data.human_review import ( @@ -67,6 +80,10 @@ from backend.data.notifications import ( remove_notifications_from_batch, ) from backend.data.onboarding import increment_onboarding_runs +from backend.data.understanding import ( + get_business_understanding, + upsert_business_understanding, +) from backend.data.user import ( get_active_user_ids_in_timerange, get_user_by_id, @@ -76,6 +93,7 @@ from backend.data.user import ( get_user_notification_preference, update_user_integrations, ) +from backend.data.workspace import get_or_create_workspace from backend.util.service import ( AppService, AppServiceClient, @@ -107,6 +125,13 @@ async def _get_credits(user_id: str) -> int: class DatabaseManager(AppService): + """Database connection pooling service. + + This service connects to the Prisma engine and exposes database + operations via RPC endpoints. It acts as a centralized connection pool + for all services that need database access. + """ + @asynccontextmanager async def lifespan(self, app: "FastAPI"): async with super().lifespan(app): @@ -142,11 +167,15 @@ class DatabaseManager(AppService): def _( f: Callable[P, R], name: str | None = None ) -> Callable[Concatenate[object, P], R]: + """ + Exposes a function as an RPC endpoint, and adds a virtual `self` param + to the function's type so it can be bound as a method. + """ if name is not None: f.__name__ = name return cast(Callable[Concatenate[object, P], R], expose(f)) - # Executions + # ============ Graph Executions ============ # get_child_graph_executions = _(get_child_graph_executions) get_graph_executions = _(get_graph_executions) get_graph_executions_count = _(get_graph_executions_count) @@ -170,36 +199,37 @@ class DatabaseManager(AppService): get_frequently_executed_graphs = _(get_frequently_executed_graphs) get_marketplace_graphs_for_monitoring = _(get_marketplace_graphs_for_monitoring) - # Graphs + # ============ Graphs ============ # get_node = _(get_node) get_graph = _(get_graph) get_connected_output_nodes = _(get_connected_output_nodes) get_graph_metadata = _(get_graph_metadata) get_graph_settings = _(get_graph_settings) + get_store_listed_graphs = _(get_store_listed_graphs) - # Credits + # ============ Credits ============ # spend_credits = _(_spend_credits, name="spend_credits") get_credits = _(_get_credits, name="get_credits") - # User + User Metadata + User Integrations + # ============ User + Integrations ============ # + get_user_by_id = _(get_user_by_id) get_user_integrations = _(get_user_integrations) update_user_integrations = _(update_user_integrations) - # User Comms - async + # ============ User Comms ============ # get_active_user_ids_in_timerange = _(get_active_user_ids_in_timerange) - get_user_by_id = _(get_user_by_id) get_user_email_by_id = _(get_user_email_by_id) get_user_email_verification = _(get_user_email_verification) get_user_notification_preference = _(get_user_notification_preference) - # Human In The Loop + # ============ Human In The Loop ============ # cancel_pending_reviews_for_execution = _(cancel_pending_reviews_for_execution) check_approval = _(check_approval) get_or_create_human_review = _(get_or_create_human_review) has_pending_reviews_for_graph_exec = _(has_pending_reviews_for_graph_exec) update_review_processed_status = _(update_review_processed_status) - # Notifications - async + # ============ Notifications ============ # clear_all_user_notification_batches = _(clear_all_user_notification_batches) create_or_add_to_user_notification_batch = _( create_or_add_to_user_notification_batch @@ -212,29 +242,56 @@ class DatabaseManager(AppService): get_user_notification_oldest_message_in_batch ) - # Library + # ============ Library ============ # list_library_agents = _(list_library_agents) add_store_agent_to_library = _(add_store_agent_to_library) + create_graph_in_library = _(create_graph_in_library) + create_library_agent = _(create_library_agent) + get_library_agent = _(get_library_agent) + get_library_agent_by_graph_id = _(get_library_agent_by_graph_id) + update_graph_in_library = _(update_graph_in_library) validate_graph_execution_permissions = _(validate_graph_execution_permissions) - # Onboarding + # ============ Onboarding ============ # increment_onboarding_runs = _(increment_onboarding_runs) - # OAuth + # ============ OAuth ============ # cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens) - # Store + # ============ Store ============ # get_store_agents = _(get_store_agents) get_store_agent_details = _(get_store_agent_details) + get_agent = _(get_agent) + get_available_graph = _(get_available_graph) - # Store Embeddings + # ============ Search ============ # get_embedding_stats = _(get_embedding_stats) backfill_missing_embeddings = _(backfill_missing_embeddings) cleanup_orphaned_embeddings = _(cleanup_orphaned_embeddings) + unified_hybrid_search = _(unified_hybrid_search) - # Summary data - async + # ============ Summary Data ============ # get_user_execution_summary_data = _(get_user_execution_summary_data) + # ============ Workspace ============ # + get_or_create_workspace = _(get_or_create_workspace) + + # ============ Understanding ============ # + get_business_understanding = _(get_business_understanding) + upsert_business_understanding = _(upsert_business_understanding) + + # ============ CoPilot Chat Sessions ============ # + get_chat_session = _(chat_db.get_chat_session) + create_chat_session = _(chat_db.create_chat_session) + update_chat_session = _(chat_db.update_chat_session) + add_chat_message = _(chat_db.add_chat_message) + add_chat_messages_batch = _(chat_db.add_chat_messages_batch) + get_user_chat_sessions = _(chat_db.get_user_chat_sessions) + get_user_session_count = _(chat_db.get_user_session_count) + delete_chat_session = _(chat_db.delete_chat_session) + get_chat_session_message_count = _(chat_db.get_chat_session_message_count) + update_tool_message_content = _(chat_db.update_tool_message_content) + class DatabaseManagerClient(AppServiceClient): d = DatabaseManager @@ -296,43 +353,50 @@ class DatabaseManagerAsyncClient(AppServiceClient): def get_service_type(cls): return DatabaseManager + # ============ Graph Executions ============ # create_graph_execution = d.create_graph_execution get_child_graph_executions = d.get_child_graph_executions get_connected_output_nodes = d.get_connected_output_nodes get_latest_node_execution = d.get_latest_node_execution - get_graph = d.get_graph - get_graph_metadata = d.get_graph_metadata - get_graph_settings = d.get_graph_settings get_graph_execution = d.get_graph_execution get_graph_execution_meta = d.get_graph_execution_meta - get_node = d.get_node + get_graph_executions = d.get_graph_executions get_node_execution = d.get_node_execution get_node_executions = d.get_node_executions - get_user_by_id = d.get_user_by_id - get_user_integrations = d.get_user_integrations - upsert_execution_input = d.upsert_execution_input - upsert_execution_output = d.upsert_execution_output - get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id update_graph_execution_stats = d.update_graph_execution_stats update_node_execution_status = d.update_node_execution_status update_node_execution_status_batch = d.update_node_execution_status_batch - update_user_integrations = d.update_user_integrations + upsert_execution_input = d.upsert_execution_input + upsert_execution_output = d.upsert_execution_output + get_execution_outputs_by_node_exec_id = d.get_execution_outputs_by_node_exec_id get_execution_kv_data = d.get_execution_kv_data set_execution_kv_data = d.set_execution_kv_data - # Human In The Loop + # ============ Graphs ============ # + get_graph = d.get_graph + get_graph_metadata = d.get_graph_metadata + get_graph_settings = d.get_graph_settings + get_node = d.get_node + get_store_listed_graphs = d.get_store_listed_graphs + + # ============ User + Integrations ============ # + get_user_by_id = d.get_user_by_id + get_user_integrations = d.get_user_integrations + update_user_integrations = d.update_user_integrations + + # ============ Human In The Loop ============ # cancel_pending_reviews_for_execution = d.cancel_pending_reviews_for_execution check_approval = d.check_approval get_or_create_human_review = d.get_or_create_human_review update_review_processed_status = d.update_review_processed_status - # User Comms + # ============ User Comms ============ # get_active_user_ids_in_timerange = d.get_active_user_ids_in_timerange get_user_email_by_id = d.get_user_email_by_id get_user_email_verification = d.get_user_email_verification get_user_notification_preference = d.get_user_notification_preference - # Notifications + # ============ Notifications ============ # clear_all_user_notification_batches = d.clear_all_user_notification_batches create_or_add_to_user_notification_batch = ( d.create_or_add_to_user_notification_batch @@ -345,20 +409,49 @@ class DatabaseManagerAsyncClient(AppServiceClient): d.get_user_notification_oldest_message_in_batch ) - # Library + # ============ Library ============ # list_library_agents = d.list_library_agents add_store_agent_to_library = d.add_store_agent_to_library + create_graph_in_library = d.create_graph_in_library + create_library_agent = d.create_library_agent + get_library_agent = d.get_library_agent + get_library_agent_by_graph_id = d.get_library_agent_by_graph_id + update_graph_in_library = d.update_graph_in_library validate_graph_execution_permissions = d.validate_graph_execution_permissions - # Onboarding + # ============ Onboarding ============ # increment_onboarding_runs = d.increment_onboarding_runs - # OAuth + # ============ OAuth ============ # cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens - # Store + # ============ Store ============ # get_store_agents = d.get_store_agents get_store_agent_details = d.get_store_agent_details + get_agent = d.get_agent + get_available_graph = d.get_available_graph - # Summary data + # ============ Search ============ # + unified_hybrid_search = d.unified_hybrid_search + + # ============ Summary Data ============ # get_user_execution_summary_data = d.get_user_execution_summary_data + + # ============ Workspace ============ # + get_or_create_workspace = d.get_or_create_workspace + + # ============ Understanding ============ # + get_business_understanding = d.get_business_understanding + upsert_business_understanding = d.upsert_business_understanding + + # ============ CoPilot Chat Sessions ============ # + get_chat_session = d.get_chat_session + create_chat_session = d.create_chat_session + update_chat_session = d.update_chat_session + add_chat_message = d.add_chat_message + add_chat_messages_batch = d.add_chat_messages_batch + get_user_chat_sessions = d.get_user_chat_sessions + get_user_session_count = d.get_user_session_count + delete_chat_session = d.delete_chat_session + get_chat_session_message_count = d.get_chat_session_message_count + update_tool_message_content = d.update_tool_message_content diff --git a/autogpt_platform/backend/backend/db.py b/autogpt_platform/backend/backend/db.py index 5c59a98a00..2661405f6d 100644 --- a/autogpt_platform/backend/backend/db.py +++ b/autogpt_platform/backend/backend/db.py @@ -1,5 +1,5 @@ from backend.app import run_processes -from backend.executor import DatabaseManager +from backend.data.db_manager import DatabaseManager def main(): diff --git a/autogpt_platform/backend/backend/executor/__init__.py b/autogpt_platform/backend/backend/executor/__init__.py index 92d8b5dc58..883bb226e6 100644 --- a/autogpt_platform/backend/backend/executor/__init__.py +++ b/autogpt_platform/backend/backend/executor/__init__.py @@ -1,11 +1,7 @@ -from .database import DatabaseManager, DatabaseManagerAsyncClient, DatabaseManagerClient from .manager import ExecutionManager from .scheduler import Scheduler __all__ = [ - "DatabaseManager", - "DatabaseManagerClient", - "DatabaseManagerAsyncClient", "ExecutionManager", "Scheduler", ] diff --git a/autogpt_platform/backend/backend/executor/activity_status_generator.py b/autogpt_platform/backend/backend/executor/activity_status_generator.py index 3bc6bcb876..d7ec7beb49 100644 --- a/autogpt_platform/backend/backend/executor/activity_status_generator.py +++ b/autogpt_platform/backend/backend/executor/activity_status_generator.py @@ -22,7 +22,7 @@ from backend.util.settings import Settings from backend.util.truncate import truncate if TYPE_CHECKING: - from backend.executor import DatabaseManagerAsyncClient + from backend.data.db_manager import DatabaseManagerAsyncClient logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/executor/automod/manager.py b/autogpt_platform/backend/backend/executor/automod/manager.py index 81001196dd..2eef4f6eca 100644 --- a/autogpt_platform/backend/backend/executor/automod/manager.py +++ b/autogpt_platform/backend/backend/executor/automod/manager.py @@ -4,7 +4,7 @@ import logging from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: - from backend.executor import DatabaseManagerAsyncClient + from backend.data.db_manager import DatabaseManagerAsyncClient from pydantic import ValidationError diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 8362dae828..440d794509 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -96,7 +96,10 @@ from .utils import ( ) if TYPE_CHECKING: - from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient + from backend.data.db_manager import ( + DatabaseManagerAsyncClient, + DatabaseManagerClient, + ) _logger = logging.getLogger(__name__) diff --git a/autogpt_platform/backend/backend/util/clients.py b/autogpt_platform/backend/backend/util/clients.py index 1cf2c6e49d..9e5a913abb 100644 --- a/autogpt_platform/backend/backend/util/clients.py +++ b/autogpt_platform/backend/backend/util/clients.py @@ -12,12 +12,15 @@ settings = Settings() if TYPE_CHECKING: from openai import AsyncOpenAI + from backend.data.db_manager import ( + DatabaseManagerAsyncClient, + DatabaseManagerClient, + ) from backend.data.execution import ( AsyncRedisExecutionEventBus, RedisExecutionEventBus, ) from backend.data.rabbitmq import AsyncRabbitMQ, SyncRabbitMQ - from backend.executor import DatabaseManagerAsyncClient, DatabaseManagerClient from backend.executor.scheduler import SchedulerClient from backend.integrations.credentials_store import IntegrationCredentialsStore from backend.notifications.notifications import NotificationManagerClient @@ -27,7 +30,7 @@ if TYPE_CHECKING: @thread_cached def get_database_manager_client() -> "DatabaseManagerClient": """Get a thread-cached DatabaseManagerClient with request retry enabled.""" - from backend.executor import DatabaseManagerClient + from backend.data.db_manager import DatabaseManagerClient from backend.util.service import get_service_client return get_service_client(DatabaseManagerClient, request_retry=True) @@ -38,7 +41,7 @@ def get_database_manager_async_client( should_retry: bool = True, ) -> "DatabaseManagerAsyncClient": """Get a thread-cached DatabaseManagerAsyncClient with request retry enabled.""" - from backend.executor import DatabaseManagerAsyncClient + from backend.data.db_manager import DatabaseManagerAsyncClient from backend.util.service import get_service_client return get_service_client(DatabaseManagerAsyncClient, request_retry=should_retry) diff --git a/autogpt_platform/backend/backend/util/test.py b/autogpt_platform/backend/backend/util/test.py index 23d7c24147..e9336dd679 100644 --- a/autogpt_platform/backend/backend/util/test.py +++ b/autogpt_platform/backend/backend/util/test.py @@ -10,6 +10,7 @@ from autogpt_libs.auth import get_user_id from backend.api.rest_api import AgentServer from backend.data import db from backend.data.block import Block, BlockSchema, initialize_blocks +from backend.data.db_manager import DatabaseManager from backend.data.execution import ( ExecutionContext, ExecutionStatus, @@ -18,7 +19,7 @@ from backend.data.execution import ( ) from backend.data.model import _BaseCredentials from backend.data.user import create_default_user -from backend.executor import DatabaseManager, ExecutionManager, Scheduler +from backend.executor import ExecutionManager, Scheduler from backend.notifications.notifications import NotificationManager log = logging.getLogger(__name__)