mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Introduces SQLModel (SQLAlchemy + Pydantic) as an ORM layer to enable future database backend switching (PostgreSQL, MySQL). All services except session_queue have been migrated to SQLModel-based implementations while keeping the existing migration system and raw SQLite connection intact for backwards compatibility.

Key changes:
- Add sqlmodel dependency
- Define SQLModel table models for all 14 database tables
- Extend SqliteDatabase with SQLAlchemy Engine and Session management
- Create SQLModel implementations for 10 services (boards, images, workflows, models, users, style presets, app settings, etc.)
- Session queue remains on raw SQLite (Phase 3)
- Add 95 unit tests and 12 performance benchmarks
- Optimize with StaticPool, expire_on_commit=False, and read-only sessions
178 lines
7.4 KiB
Python
178 lines
7.4 KiB
Python
import sqlite3
|
|
import threading
|
|
from collections.abc import Generator
|
|
from contextlib import contextmanager
|
|
from logging import Logger
|
|
from pathlib import Path
|
|
from uuid import uuid4
|
|
|
|
from sqlalchemy import event
|
|
from sqlalchemy.pool import StaticPool
|
|
from sqlmodel import Session, create_engine
|
|
|
|
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
|
|
|
|
|
class SqliteDatabase:
|
|
"""
|
|
Manages a connection to an SQLite database.
|
|
|
|
:param db_path: Path to the database file. If None, an in-memory database is used.
|
|
:param logger: Logger to use for logging.
|
|
:param verbose: Whether to log SQL statements. Provides `logger.debug` as the SQLite trace callback.
|
|
|
|
This is a light wrapper around the `sqlite3` module, providing a few conveniences:
|
|
- The database file is written to disk if it does not exist.
|
|
- Foreign key constraints are enabled by default.
|
|
- The connection is configured to use the `sqlite3.Row` row factory.
|
|
|
|
In addition to the constructor args, the instance provides the following attributes and methods:
|
|
- `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
|
|
- `lock`: A shared re-entrant lock, used to approximate thread safety.
|
|
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
|
|
- `get_session()`: Returns a SQLModel Session for ORM-based queries.
|
|
"""
|
|
|
|
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
|
"""Initializes the database. This is used internally by the class constructor."""
|
|
self._logger = logger
|
|
self._db_path = db_path
|
|
self._verbose = verbose
|
|
self._lock = threading.RLock()
|
|
|
|
if not self._db_path:
|
|
logger.info("Initializing in-memory database")
|
|
else:
|
|
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
|
self._logger.info(f"Initializing database at {self._db_path}")
|
|
|
|
self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
|
|
self._conn.row_factory = sqlite3.Row
|
|
|
|
if self._verbose:
|
|
self._conn.set_trace_callback(self._logger.debug)
|
|
|
|
# Enable foreign key constraints
|
|
self._conn.execute("PRAGMA foreign_keys = ON;")
|
|
|
|
# Enable Write-Ahead Logging (WAL) mode for better concurrency
|
|
self._conn.execute("PRAGMA journal_mode = WAL;")
|
|
|
|
# Set a busy timeout to prevent database lockups during writes
|
|
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
|
|
|
# Set up the SQLAlchemy engine for SQLModel-based queries.
|
|
# For file-based DBs, both connections point to the same file.
|
|
# For in-memory DBs, we use a named shared cache so both connections
|
|
# see the same database.
|
|
if self._db_path:
|
|
db_uri = f"sqlite:///{self._db_path}"
|
|
# StaticPool reuses a single connection — ideal for SQLite which
|
|
# serializes writes anyway. Avoids the overhead of creating a new
|
|
# connection for every Session.
|
|
self._engine = create_engine(
|
|
db_uri,
|
|
echo=self._verbose,
|
|
connect_args={"check_same_thread": False},
|
|
poolclass=StaticPool,
|
|
)
|
|
else:
|
|
# Use a shared in-memory database via URI with shared cache.
|
|
# The raw sqlite3 connection above already created ":memory:",
|
|
# so we re-create it with the shared URI instead.
|
|
shared_uri = f"file:invokeai_memdb_{uuid4().hex}?mode=memory&cache=shared"
|
|
self._conn.close()
|
|
self._conn = sqlite3.connect(shared_uri, uri=True, check_same_thread=False)
|
|
self._conn.row_factory = sqlite3.Row
|
|
if self._verbose:
|
|
self._conn.set_trace_callback(self._logger.debug)
|
|
self._conn.execute("PRAGMA foreign_keys = ON;")
|
|
|
|
self._engine = create_engine(
|
|
"sqlite+pysqlite://",
|
|
echo=self._verbose,
|
|
creator=lambda: sqlite3.connect(shared_uri, uri=True, check_same_thread=False),
|
|
poolclass=StaticPool,
|
|
)
|
|
|
|
# Apply the same PRAGMAs to all SQLAlchemy connections
|
|
@event.listens_for(self._engine, "connect")
|
|
def _set_sqlite_pragmas(dbapi_connection, connection_record): # type: ignore
|
|
cursor = dbapi_connection.cursor()
|
|
# Note: We intentionally skip PRAGMA foreign_keys for the SQLAlchemy engine.
|
|
# Migration 22 renames the `models` table which corrupts FK references in
|
|
# `model_relationships`. The raw sqlite3 connection already enforces FKs
|
|
# for the migration phase. The SQLAlchemy engine is used only for queries
|
|
# after migrations are complete.
|
|
if self._db_path:
|
|
cursor.execute("PRAGMA journal_mode = WAL;")
|
|
cursor.execute("PRAGMA busy_timeout = 5000;")
|
|
cursor.close()
|
|
|
|
def clean(self) -> None:
|
|
"""
|
|
Cleans the database by running the VACUUM command, reporting on the freed space.
|
|
"""
|
|
# No need to clean in-memory database
|
|
if not self._db_path:
|
|
return
|
|
try:
|
|
with self._conn as conn:
|
|
initial_db_size = Path(self._db_path).stat().st_size
|
|
conn.execute("VACUUM;")
|
|
conn.commit()
|
|
final_db_size = Path(self._db_path).stat().st_size
|
|
freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
|
|
if freed_space_in_mb > 0:
|
|
self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
|
|
except Exception as e:
|
|
self._logger.error(f"Error cleaning database: {e}")
|
|
raise
|
|
|
|
@contextmanager
|
|
def get_session(self) -> Generator[Session, None, None]:
|
|
"""
|
|
Context manager that yields a SQLModel Session for write operations.
|
|
Commits on success, rolls back on exception.
|
|
|
|
Uses expire_on_commit=False so that model attributes remain accessible
|
|
after commit without triggering lazy-loads or DetachedInstanceError.
|
|
"""
|
|
with Session(self._engine, expire_on_commit=False) as session:
|
|
try:
|
|
yield session
|
|
session.commit()
|
|
except Exception:
|
|
session.rollback()
|
|
raise
|
|
|
|
@contextmanager
|
|
def get_readonly_session(self) -> Generator[Session, None, None]:
|
|
"""
|
|
Context manager that yields a lightweight read-only SQLModel Session.
|
|
|
|
Optimized for SELECT queries:
|
|
- autoflush=False: skips the automatic flush before every query
|
|
- no commit/rollback: avoids transaction overhead for reads
|
|
- expire_on_commit=False: attributes stay accessible after close
|
|
"""
|
|
with Session(self._engine, expire_on_commit=False, autoflush=False) as session:
|
|
yield session
|
|
|
|
@contextmanager
|
|
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
|
|
"""
|
|
Thread-safe context manager for DB work.
|
|
Acquires the RLock, yields a Cursor, then commits or rolls back.
|
|
"""
|
|
with self._lock:
|
|
cursor = self._conn.cursor()
|
|
try:
|
|
yield cursor
|
|
self._conn.commit()
|
|
except Exception:
|
|
self._conn.rollback()
|
|
raise
|
|
finally:
|
|
cursor.close()
|