mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-13 17:18:08 -05:00
Compare commits
7 Commits
fix/execut
...
swiftyos/c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
133c99211f | ||
|
|
09a9ba83c7 | ||
|
|
b50954cf0e | ||
|
|
4a7a9c5cd0 | ||
|
|
066de6c786 | ||
|
|
f3610c7755 | ||
|
|
1f6409c954 |
@@ -17,6 +17,13 @@ DATABASE_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME
|
||||
DIRECT_URL="postgresql://${DB_USER}:${DB_PASS}@${DB_HOST}:${DB_PORT}/${DB_NAME}?schema=${DB_SCHEMA}&connect_timeout=${DB_CONNECT_TIMEOUT}"
|
||||
PRISMA_SCHEMA="postgres/schema.prisma"
|
||||
|
||||
# SQLAlchemy Configuration (for gradual migration from Prisma)
|
||||
SQLALCHEMY_POOL_SIZE=10
|
||||
SQLALCHEMY_MAX_OVERFLOW=5
|
||||
SQLALCHEMY_POOL_TIMEOUT=30
|
||||
SQLALCHEMY_CONNECT_TIMEOUT=10
|
||||
SQLALCHEMY_ECHO=false
|
||||
|
||||
## ===== REQUIRED SERVICE CREDENTIALS ===== ##
|
||||
# Redis Configuration
|
||||
REDIS_HOST=localhost
|
||||
|
||||
235
autogpt_platform/backend/backend/data/sqlalchemy.py
Normal file
235
autogpt_platform/backend/backend/data/sqlalchemy.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
SQLAlchemy infrastructure for AutoGPT Platform.
|
||||
|
||||
This module provides:
|
||||
1. Async engine creation with connection pooling
|
||||
2. Session factory for dependency injection
|
||||
3. Database lifecycle management
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import (
|
||||
AsyncEngine,
|
||||
AsyncSession,
|
||||
async_sessionmaker,
|
||||
create_async_engine,
|
||||
)
|
||||
from sqlalchemy.pool import QueuePool
|
||||
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ============================================================================
|
||||
# CONFIGURATION
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def get_database_url() -> str:
|
||||
"""
|
||||
Extract database URL from environment and convert to async format.
|
||||
|
||||
Prisma URL: postgresql://user:pass@host:port/db?schema=platform
|
||||
Async URL: postgresql+asyncpg://user:pass@host:port/db
|
||||
|
||||
Returns the async-compatible URL without schema parameter (handled separately).
|
||||
"""
|
||||
prisma_url = Config().database_url
|
||||
|
||||
# Replace postgresql:// with postgresql+asyncpg://
|
||||
async_url = prisma_url.replace("postgresql://", "postgresql+asyncpg://")
|
||||
|
||||
# Remove schema parameter (we'll handle via MetaData)
|
||||
async_url = re.sub(r"\?schema=\w+", "", async_url)
|
||||
|
||||
# Remove any remaining query parameters that might conflict
|
||||
async_url = re.sub(r"&schema=\w+", "", async_url)
|
||||
|
||||
return async_url
|
||||
|
||||
|
||||
def get_database_schema() -> str:
    """
    Extract the schema name from the DATABASE_URL query parameter.

    Returns 'platform' by default (matches the Prisma configuration).
    """
    url = Config().database_url
    found = re.search(r"schema=(\w+)", url)
    if found is None:
        return "platform"
    return found.group(1)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENGINE CREATION
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def create_engine() -> AsyncEngine:
    """
    Create async SQLAlchemy engine with connection pooling.

    This should be called ONCE per process at startup.
    The engine is long-lived and thread-safe.

    Connection Pool Configuration:
    - pool_size: Number of persistent connections (default: 10)
    - max_overflow: Additional connections when pool exhausted (default: 5)
    - pool_timeout: Seconds to wait for connection (default: 30)
    - pool_pre_ping: Test connections before using (prevents stale connections)

    Total max connections = pool_size + max_overflow = 15

    Returns:
        A configured AsyncEngine backed by asyncpg.
    """
    url = get_database_url()
    config = Config()
    schema = get_database_schema()  # resolve once; used in connect_args and logging

    engine = create_async_engine(
        url,
        # Connection pool configuration.
        # NOTE: no poolclass override here. The sync QueuePool is rejected by
        # asyncio engines (SQLAlchemy raises ArgumentError: "Pool class
        # QueuePool cannot be used with asyncio engine"); create_async_engine
        # already defaults to the async-compatible AsyncAdaptedQueuePool.
        pool_size=config.sqlalchemy_pool_size,  # Persistent connections
        max_overflow=config.sqlalchemy_max_overflow,  # Burst capacity
        pool_timeout=config.sqlalchemy_pool_timeout,  # Wait time for connection
        pool_pre_ping=True,  # Validate connections before use
        # Async configuration
        echo=config.sqlalchemy_echo,  # Log SQL statements (dev/debug only)
        future=True,  # Use SQLAlchemy 2.0 style
        # Connection arguments (passed to asyncpg)
        connect_args={
            "server_settings": {
                "search_path": schema,  # Use the configured (default: platform) schema
            },
            "timeout": config.sqlalchemy_connect_timeout,  # Connection timeout
        },
    )

    # Lazy %-style args: formatting is skipped when INFO is disabled
    logger.info(
        "SQLAlchemy engine created: pool_size=%s, max_overflow=%s, schema=%s",
        config.sqlalchemy_pool_size,
        config.sqlalchemy_max_overflow,
        schema,
    )

    return engine
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SESSION FACTORY
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def create_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """
    Create a session factory for producing AsyncSession instances.

    The factory is configured once, then used to create sessions on demand.
    Each session represents a single database transaction.

    Args:
        engine: The async engine (with connection pool)

    Returns:
        Session factory that creates properly configured AsyncSession instances
    """
    # NOTE: `autocommit` must NOT be passed here. The parameter was removed
    # from Session in SQLAlchemy 2.0 (transactions always use "begin once"
    # semantics), and async_sessionmaker forwards its kwargs to Session, so
    # passing autocommit=False raises TypeError when a session is created.
    return async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,  # Don't expire ORM objects after commit
        autoflush=False,  # Manual control over when to flush
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DEPENDENCY INJECTION FOR FASTAPI
|
||||
# ============================================================================
|
||||
|
||||
# Global references (set during app startup)
|
||||
_engine: AsyncEngine | None = None
|
||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||
|
||||
|
||||
def initialize(engine: AsyncEngine) -> None:
    """
    Initialize the process-global engine and session factory.

    Called during FastAPI lifespan startup.

    Args:
        engine: The async engine to use for this process
    """
    global _engine, _session_factory

    _engine = engine
    _session_factory = create_session_factory(engine)

    logger.info("SQLAlchemy session factory initialized")
|
||||
|
||||
|
||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
    """
    FastAPI dependency that provides a database session.

    Usage in routes:
        @router.get("/users/{user_id}")
        async def get_user(
            user_id: int,
            session: AsyncSession = Depends(get_session)
        ):
            result = await session.execute(select(User).where(User.id == user_id))
            return result.scalar_one_or_none()

    Usage outside FastAPI (e.g. DatabaseManager RPC methods):
        This is an async *generator*, not an async context manager, so
        ``async with get_session() as session:`` does NOT work directly.
        Wrap it with contextlib.asynccontextmanager first:

            session_ctx = asynccontextmanager(get_session)

            @expose
            async def get_user(user_id: int):
                async with session_ctx() as session:
                    result = await session.execute(
                        select(User).where(User.id == user_id)
                    )
                    return result.scalar_one_or_none()

    Lifecycle:
        1. Request arrives
        2. FastAPI calls this dependency
        3. Session is created (borrows connection from pool)
        4. Session is injected into the route handler
        5. Route executes (may commit/rollback)
        6. On success, pending changes are committed
        7. Session is closed (returns connection to pool)

    Error handling:
        - If an exception occurs, the session is rolled back and re-raised
        - The connection is always returned to the pool (even on error)

    Raises:
        RuntimeError: if initialize() has not been called in this process.
    """
    if _session_factory is None:
        raise RuntimeError(
            "SQLAlchemy not initialized. Call initialize() in lifespan context."
        )

    # The async context manager closes the session on exit (including on
    # error), returning its connection to the pool -- no explicit
    # session.close() is needed inside the block.
    async with _session_factory() as session:
        try:
            yield session  # Inject into route handler
            # If we get here, the route succeeded - commit pending changes
            await session.commit()
        except Exception:
            # Error occurred - roll the transaction back before re-raising
            await session.rollback()
            raise
|
||||
|
||||
|
||||
async def dispose() -> None:
    """
    Dispose of the engine and close all pooled connections gracefully.

    Called during FastAPI lifespan shutdown. Safe to call when the engine
    was never initialized (no-op in that case).
    """
    global _engine, _session_factory

    if _engine is None:
        return

    logger.info("Disposing SQLAlchemy engine...")
    await _engine.dispose()
    _engine = None
    _session_factory = None
    logger.info("SQLAlchemy engine disposed")
|
||||
@@ -65,6 +65,12 @@ class UpdateTrackingModel(BaseModel, Generic[T]):
|
||||
class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
"""Config for the server."""
|
||||
|
||||
database_url: str = Field(
|
||||
default="",
|
||||
description="PostgreSQL database connection URL. "
|
||||
"Format: postgresql://user:pass@host:port/db?schema=platform&connect_timeout=60",
|
||||
)
|
||||
|
||||
num_graph_workers: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
@@ -267,6 +273,44 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="The pool size for the scheduler database connection pool",
|
||||
)
|
||||
|
||||
# SQLAlchemy Configuration
|
||||
sqlalchemy_pool_size: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Number of persistent connections in the SQLAlchemy pool. "
|
||||
"Guidelines: REST API (high traffic) 10-20, Background workers 3-5. "
|
||||
"Total across all services should not exceed PostgreSQL max_connections (default: 100).",
|
||||
)
|
||||
|
||||
sqlalchemy_max_overflow: int = Field(
|
||||
default=5,
|
||||
ge=0,
|
||||
le=50,
|
||||
description="Additional connections beyond pool_size when pool is exhausted. "
|
||||
"Total max connections = pool_size + max_overflow.",
|
||||
)
|
||||
|
||||
sqlalchemy_pool_timeout: int = Field(
|
||||
default=30,
|
||||
ge=1,
|
||||
le=300,
|
||||
description="Seconds to wait for available connection before raising error. "
|
||||
"If all connections are busy and max_overflow is reached, requests wait this long before failing.",
|
||||
)
|
||||
|
||||
sqlalchemy_connect_timeout: int = Field(
|
||||
default=10,
|
||||
ge=1,
|
||||
le=60,
|
||||
description="Seconds to wait when establishing new connection to PostgreSQL.",
|
||||
)
|
||||
|
||||
sqlalchemy_echo: bool = Field(
|
||||
default=False,
|
||||
description="Whether to log all SQL statements. Useful for debugging but very verbose. Should be False in production.",
|
||||
)
|
||||
|
||||
rabbitmq_host: str = Field(
|
||||
default="localhost",
|
||||
description="The host for the RabbitMQ server",
|
||||
|
||||
@@ -14,6 +14,7 @@ aiohttp = "^3.10.0"
|
||||
aiodns = "^3.5.0"
|
||||
anthropic = "^0.59.0"
|
||||
apscheduler = "^3.11.1"
|
||||
asyncpg = "^0.29.0"
|
||||
autogpt-libs = { path = "../autogpt_libs", develop = true }
|
||||
bleach = { extras = ["css"], version = "^6.2.0" }
|
||||
click = "^8.2.0"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
datasource db {
|
||||
provider = "postgresql"
|
||||
url = env("DATABASE_URL")
|
||||
directUrl = env("DIRECT_URL")
|
||||
provider = "postgresql"
|
||||
url = env("DIRECT_URL")
|
||||
}
|
||||
|
||||
generator client {
|
||||
@@ -665,7 +664,7 @@ view StoreAgent {
|
||||
agent_video String?
|
||||
agent_image String[]
|
||||
|
||||
featured Boolean @default(false)
|
||||
featured Boolean @default(false)
|
||||
creator_username String?
|
||||
creator_avatar String?
|
||||
sub_heading String
|
||||
@@ -675,8 +674,8 @@ view StoreAgent {
|
||||
runs Int
|
||||
rating Float
|
||||
versions String[]
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
is_available Boolean @default(true)
|
||||
useForOnboarding Boolean @default(false)
|
||||
|
||||
// Materialized views used (refreshed every 15 minutes via pg_cron):
|
||||
// - mv_agent_run_counts - Pre-aggregated agent execution counts by agentGraphId
|
||||
|
||||
Reference in New Issue
Block a user