Add SQLModel ORM layer alongside raw SQLite for database portability

Introduces SQLModel (SQLAlchemy + Pydantic) as an ORM layer to enable
future database backend switching (PostgreSQL, MySQL). All services
except session_queue have been migrated to SQLModel-based implementations
while keeping the existing migration system and raw SQLite connection
intact for backwards compatibility.

Key changes:
- Add sqlmodel dependency
- Define SQLModel table models for all 14 database tables
- Extend SqliteDatabase with SQLAlchemy Engine and Session management
- Create SQLModel implementations for 10 services (boards, images,
  workflows, models, users, style presets, app settings, etc.)
- Session queue remains on raw SQLite (Phase 3)
- Add 95 unit tests and 12 performance benchmarks
- Optimize with StaticPool, expire_on_commit=False, and read-only sessions
This commit is contained in:
Alexander Eichhorn
2026-04-19 04:19:47 +02:00
parent 63a1fe5299
commit 1a06fe0d8d
26 changed files with 3491 additions and 22 deletions

View File

@@ -5,19 +5,21 @@ from logging import Logger
import torch
from invokeai.app.services.app_settings import AppSettingsService
from invokeai.app.services.app_settings.app_settings_sqlmodel import AppSettingsServiceSqlModel
from invokeai.app.services.auth.token_service import set_jwt_secret
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
from invokeai.app.services.board_image_records.board_image_records_sqlmodel import SqlModelBoardImageRecordStorage
from invokeai.app.services.board_images.board_images_default import BoardImagesService
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
from invokeai.app.services.boards.boards_default import BoardService
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlmodel import (
ClientStatePersistenceSqlModel,
)
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.download.download_default import DownloadQueueService
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
from invokeai.app.services.images.images_default import ImageService
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
from invokeai.app.services.invocation_services import InvocationServices
@@ -25,9 +27,9 @@ from invokeai.app.services.invocation_stats.invocation_stats_default import Invo
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import (
SqliteModelRelationshipRecordStorage,
from invokeai.app.services.model_records.model_records_sqlmodel import ModelRecordServiceSqlModel
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlmodel import (
SqlModelModelRelationshipRecordStorage,
)
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
from invokeai.app.services.names.names_default import SimpleNameService
@@ -40,10 +42,10 @@ from invokeai.app.services.session_processor.session_processor_default import (
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
from invokeai.app.services.style_preset_records.style_preset_records_sqlmodel import SqlModelStylePresetRecordsStorage
from invokeai.app.services.urls.urls_default import LocalUrlService
from invokeai.app.services.users.users_default import UserService
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
from invokeai.app.services.workflow_records.workflow_records_sqlmodel import SqlModelWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
AnimaConditioningInfo,
@@ -107,7 +109,7 @@ class ApiDependencies:
db = init_db(config=config, logger=logger, image_files=image_files)
# Initialize JWT secret from database
app_settings = AppSettingsService(db=db)
app_settings = AppSettingsServiceSqlModel(db=db)
jwt_secret = app_settings.get_jwt_secret()
set_jwt_secret(jwt_secret)
logger.info("JWT secret loaded from database")
@@ -115,13 +117,13 @@ class ApiDependencies:
configuration = config
logger = logger
board_image_records = SqliteBoardImageRecordStorage(db=db)
board_image_records = SqlModelBoardImageRecordStorage(db=db)
board_images = BoardImagesService()
board_records = SqliteBoardRecordStorage(db=db)
board_records = SqlModelBoardRecordStorage(db=db)
boards = BoardService()
events = FastAPIEventService(event_handler_id, loop=loop)
bulk_download = BulkDownloadService()
image_records = SqliteImageRecordStorage(db=db)
image_records = SqlModelImageRecordStorage(db=db)
images = ImageService()
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
tensors = ObjectSerializerForwardCache(
@@ -152,23 +154,23 @@ class ApiDependencies:
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,
model_record_service=ModelRecordServiceSQL(db=db, logger=logger),
model_record_service=ModelRecordServiceSqlModel(db=db, logger=logger),
download_queue=download_queue_service,
events=events,
)
model_relationships = ModelRelationshipsService()
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
model_relationship_records = SqlModelModelRelationshipRecordStorage(db=db)
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_queue = SqliteSessionQueue(db=db)
session_queue = SqliteSessionQueue(db=db) # Stays raw SQL (Phase 3)
urls = LocalUrlService()
workflow_records = SqliteWorkflowRecordsStorage(db=db)
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
workflow_records = SqlModelWorkflowRecordsStorage(db=db)
style_preset_records = SqlModelStylePresetRecordsStorage(db=db)
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
client_state_persistence = ClientStatePersistenceSqlite(db=db)
users = UserService(db=db)
client_state_persistence = ClientStatePersistenceSqlModel(db=db)
users = UserServiceSqlModel(db=db)
services = InvocationServices(
board_image_records=board_image_records,

View File

@@ -0,0 +1,37 @@
from typing import Optional
from invokeai.app.services.shared.sqlite.models import AppSettingTable
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class AppSettingsServiceSqlModel:
"""SQLModel implementation for application-level settings."""
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
def get(self, key: str) -> Optional[str]:
try:
with self._db.get_readonly_session() as session:
row = session.get(AppSettingTable, key)
return row.value if row else None
except Exception:
return None
def set(self, key: str, value: str) -> None:
with self._db.get_session() as session:
existing = session.get(AppSettingTable, key)
if existing is not None:
existing.value = value
session.add(existing)
else:
session.add(AppSettingTable(key=key, value=value))
def get_jwt_secret(self) -> str:
secret = self.get("jwt_secret")
if secret is None:
raise RuntimeError(
"JWT secret not found in database. This should have been created during database migration. "
"Please ensure database migrations have been run successfully."
)
return secret

View File

@@ -0,0 +1,145 @@
from typing import Optional
from sqlalchemy import func
from sqlmodel import col, select
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import (
ASSETS_CATEGORIES,
IMAGE_CATEGORIES,
ImageCategory,
ImageRecord,
deserialize_image_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.models import BoardImageTable, ImageTable
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqlModelBoardImageRecordStorage(BoardImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def add_image_to_board(self, board_id: str, image_name: str) -> None:
with self._db.get_session() as session:
existing = session.get(BoardImageTable, image_name)
if existing is not None:
existing.board_id = board_id
session.add(existing)
else:
session.add(BoardImageTable(board_id=board_id, image_name=image_name))
def remove_image_from_board(self, image_name: str) -> None:
with self._db.get_session() as session:
existing = session.get(BoardImageTable, image_name)
if existing is not None:
session.delete(existing)
def get_images_for_board(
self,
board_id: str,
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.get_readonly_session() as session:
# Join board_images with images
stmt = (
select(ImageTable)
.join(BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name))
.where(col(BoardImageTable.board_id) == board_id)
.order_by(col(BoardImageTable.updated_at).desc())
)
results = session.exec(stmt).all()
images = [deserialize_image_record(_image_to_dict(r)) for r in results]
# Total count of all images
count_stmt = select(func.count()).select_from(ImageTable)
count = session.exec(count_stmt).one()
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def get_all_board_image_names_for_board(
self,
board_id: str,
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
with self._db.get_readonly_session() as session:
stmt = select(ImageTable.image_name).outerjoin(
BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name)
)
if board_id == "none":
stmt = stmt.where(col(BoardImageTable.board_id).is_(None))
else:
stmt = stmt.where(col(BoardImageTable.board_id) == board_id)
if categories is not None:
category_strings = [c.value for c in set(categories)]
stmt = stmt.where(col(ImageTable.image_category).in_(category_strings))
if is_intermediate is not None:
stmt = stmt.where(col(ImageTable.is_intermediate) == is_intermediate)
results = session.exec(stmt).all()
return list(results)
def get_board_for_image(self, image_name: str) -> Optional[str]:
with self._db.get_readonly_session() as session:
row = session.get(BoardImageTable, image_name)
if row is None:
return None
return row.board_id
def get_image_count_for_board(self, board_id: str) -> int:
category_strings = [c.value for c in set(IMAGE_CATEGORIES)]
with self._db.get_readonly_session() as session:
stmt = (
select(func.count())
.select_from(BoardImageTable)
.join(ImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name))
.where(
col(ImageTable.is_intermediate) == False, # noqa: E712
col(ImageTable.image_category).in_(category_strings),
col(BoardImageTable.board_id) == board_id,
)
)
count = session.exec(stmt).one()
return count
def get_asset_count_for_board(self, board_id: str) -> int:
category_strings = [c.value for c in set(ASSETS_CATEGORIES)]
with self._db.get_readonly_session() as session:
stmt = (
select(func.count())
.select_from(BoardImageTable)
.join(ImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name))
.where(
col(ImageTable.is_intermediate) == False, # noqa: E712
col(ImageTable.image_category).in_(category_strings),
col(BoardImageTable.board_id) == board_id,
)
)
count = session.exec(stmt).one()
return count
def _image_to_dict(row: ImageTable) -> dict:
"""Convert an ImageTable row to a dict compatible with deserialize_image_record."""
return {
"image_name": row.image_name,
"image_origin": row.image_origin,
"image_category": row.image_category,
"width": row.width,
"height": row.height,
"session_id": row.session_id,
"node_id": row.node_id,
"metadata": row.metadata_,
"is_intermediate": row.is_intermediate,
"created_at": row.created_at,
"updated_at": row.updated_at,
"deleted_at": row.deleted_at,
"starred": row.starred,
"has_workflow": row.has_workflow,
}

View File

@@ -0,0 +1,177 @@
from sqlalchemy import func
from sqlmodel import col, select
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
from invokeai.app.services.board_records.board_records_common import (
BoardChanges,
BoardRecord,
BoardRecordDeleteException,
BoardRecordNotFoundException,
BoardRecordOrderBy,
BoardRecordSaveException,
BoardVisibility,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.models import BoardTable
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.util.misc import uuid_string
def _to_record(row: BoardTable) -> BoardRecord:
"""Convert a SQLModel BoardTable row to a BoardRecord pydantic model.
Must be called while the row is still bound to an active Session.
"""
try:
visibility = BoardVisibility(row.board_visibility)
except ValueError:
visibility = BoardVisibility.Private
return BoardRecord(
board_id=row.board_id,
board_name=row.board_name,
user_id=row.user_id,
cover_image_name=row.cover_image_name,
created_at=row.created_at,
updated_at=row.updated_at,
deleted_at=row.deleted_at,
archived=row.archived,
board_visibility=visibility,
)
class SqlModelBoardRecordStorage(BoardRecordStorageBase):
"""Board record storage using SQLModel."""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def delete(self, board_id: str) -> None:
with self._db.get_session() as session:
try:
board = session.get(BoardTable, board_id)
if board:
session.delete(board)
except Exception as e:
raise BoardRecordDeleteException from e
def save(self, board_name: str, user_id: str) -> BoardRecord:
board_id = uuid_string()
board = BoardTable(board_id=board_id, board_name=board_name, user_id=user_id)
with self._db.get_session() as session:
try:
session.add(board)
session.flush()
return _to_record(board)
except Exception as e:
raise BoardRecordSaveException from e
def get(self, board_id: str) -> BoardRecord:
with self._db.get_readonly_session() as session:
board = session.get(BoardTable, board_id)
if board is None:
raise BoardRecordNotFoundException
return _to_record(board)
def update(self, board_id: str, changes: BoardChanges) -> BoardRecord:
with self._db.get_session() as session:
try:
board = session.get(BoardTable, board_id)
if board is None:
raise BoardRecordNotFoundException
if changes.board_name is not None:
board.board_name = changes.board_name
if changes.cover_image_name is not None:
board.cover_image_name = changes.cover_image_name
if changes.archived is not None:
board.archived = changes.archived
if changes.board_visibility is not None:
board.board_visibility = changes.board_visibility.value
session.add(board)
session.flush()
return _to_record(board)
except BoardRecordNotFoundException:
raise
except Exception as e:
raise BoardRecordSaveException from e
def get_many(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
offset: int = 0,
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
with self._db.get_readonly_session() as session:
# Build filter conditions
conditions = []
if not is_admin:
conditions.append(
(col(BoardTable.user_id) == user_id) | (col(BoardTable.board_visibility).in_(["shared", "public"]))
)
if not include_archived:
conditions.append(col(BoardTable.archived) == False) # noqa: E712
# Count query
count_stmt = select(func.count()).select_from(BoardTable)
for cond in conditions:
count_stmt = count_stmt.where(cond)
total = session.exec(count_stmt).one()
# Data query
stmt = select(BoardTable)
for cond in conditions:
stmt = stmt.where(cond)
# Apply ordering
order_col = (
col(BoardTable.created_at) if order_by == BoardRecordOrderBy.CreatedAt else col(BoardTable.board_name)
)
stmt = stmt.order_by(order_col.desc() if direction == SQLiteDirection.Descending else order_col.asc())
stmt = stmt.offset(offset).limit(limit)
results = session.exec(stmt).all()
boards = [_to_record(r) for r in results]
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=total)
def get_all(
self,
user_id: str,
is_admin: bool,
order_by: BoardRecordOrderBy,
direction: SQLiteDirection,
include_archived: bool = False,
) -> list[BoardRecord]:
with self._db.get_readonly_session() as session:
stmt = select(BoardTable)
if not is_admin:
stmt = stmt.where(
(col(BoardTable.user_id) == user_id) | (col(BoardTable.board_visibility).in_(["shared", "public"]))
)
if not include_archived:
stmt = stmt.where(col(BoardTable.archived) == False) # noqa: E712
# Apply ordering
if order_by == BoardRecordOrderBy.Name:
order_col = col(BoardTable.board_name)
else:
order_col = col(BoardTable.created_at)
stmt = stmt.order_by(order_col.desc() if direction == SQLiteDirection.Descending else order_col.asc())
results = session.exec(stmt).all()
boards = [_to_record(r) for r in results]
return boards

View File

@@ -0,0 +1,41 @@
from sqlmodel import col, select
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.models import ClientStateTable
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class ClientStatePersistenceSqlModel(ClientStatePersistenceABC):
"""SQLModel implementation for client state persistence."""
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
def set_by_key(self, user_id: str, key: str, value: str) -> str:
with self._db.get_session() as session:
existing = session.get(ClientStateTable, (user_id, key))
if existing is not None:
existing.value = value
session.add(existing)
else:
session.add(ClientStateTable(user_id=user_id, key=key, value=value))
return value
def get_by_key(self, user_id: str, key: str) -> str | None:
with self._db.get_readonly_session() as session:
row = session.get(ClientStateTable, (user_id, key))
if row is None:
return None
return row.value
def delete(self, user_id: str) -> None:
with self._db.get_session() as session:
stmt = select(ClientStateTable).where(col(ClientStateTable.user_id) == user_id)
rows = session.exec(stmt).all()
for row in rows:
session.delete(row)

View File

@@ -0,0 +1,367 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import func
from sqlmodel import col, select
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
ImageRecordNotFoundException,
ImageRecordSaveException,
ResourceOrigin,
deserialize_image_record,
)
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
from invokeai.app.services.shared.sqlite.models import BoardImageTable, ImageTable
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
def _to_dict(row: ImageTable) -> dict:
return {
"image_name": row.image_name,
"image_origin": row.image_origin,
"image_category": row.image_category,
"width": row.width,
"height": row.height,
"session_id": row.session_id,
"node_id": row.node_id,
"metadata": row.metadata_,
"is_intermediate": row.is_intermediate,
"created_at": row.created_at,
"updated_at": row.updated_at,
"deleted_at": row.deleted_at,
"starred": row.starred,
"has_workflow": row.has_workflow,
}
class SqlModelImageRecordStorage(ImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def get(self, image_name: str) -> ImageRecord:
with self._db.get_readonly_session() as session:
row = session.get(ImageTable, image_name)
if row is None:
raise ImageRecordNotFoundException
return deserialize_image_record(_to_dict(row))
def get_user_id(self, image_name: str) -> Optional[str]:
with self._db.get_readonly_session() as session:
row = session.get(ImageTable, image_name)
if row is None:
return None
return row.user_id
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
with self._db.get_readonly_session() as session:
row = session.get(ImageTable, image_name)
if row is None:
raise ImageRecordNotFoundException
if row.metadata_ is None:
return None
return MetadataFieldValidator.validate_json(row.metadata_)
def update(self, image_name: str, changes: ImageRecordChanges) -> None:
with self._db.get_session() as session:
try:
row = session.get(ImageTable, image_name)
if row is None:
raise ImageRecordNotFoundException
if changes.image_category is not None:
row.image_category = changes.image_category.value
if changes.session_id is not None:
row.session_id = changes.session_id
if changes.is_intermediate is not None:
row.is_intermediate = changes.is_intermediate
if changes.starred is not None:
row.starred = changes.starred
session.add(row)
except ImageRecordNotFoundException:
raise
except Exception as e:
raise ImageRecordSaveException from e
def get_many(
self,
offset: int = 0,
limit: int = 10,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> OffsetPaginatedResults[ImageRecord]:
with self._db.get_readonly_session() as session:
# Base query with left join to board_images
stmt = select(ImageTable).outerjoin(
BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name)
)
count_stmt = (
select(func.count())
.select_from(ImageTable)
.outerjoin(BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name))
)
# Apply filters
stmt, count_stmt = self._apply_image_filters(
stmt,
count_stmt,
image_origin,
categories,
is_intermediate,
board_id,
search_term,
user_id,
is_admin,
)
# Count
total = session.exec(count_stmt).one()
# Ordering
if starred_first:
stmt = stmt.order_by(
col(ImageTable.starred).desc(),
col(ImageTable.created_at).desc()
if order_dir == SQLiteDirection.Descending
else col(ImageTable.created_at).asc(),
)
else:
stmt = stmt.order_by(
col(ImageTable.created_at).desc()
if order_dir == SQLiteDirection.Descending
else col(ImageTable.created_at).asc(),
)
stmt = stmt.limit(limit).offset(offset)
results = session.exec(stmt).all()
images = [deserialize_image_record(_to_dict(r)) for r in results]
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=total)
def delete(self, image_name: str) -> None:
with self._db.get_session() as session:
try:
row = session.get(ImageTable, image_name)
if row is not None:
session.delete(row)
except Exception as e:
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
with self._db.get_session() as session:
try:
stmt = select(ImageTable).where(col(ImageTable.image_name).in_(image_names))
rows = session.exec(stmt).all()
for row in rows:
session.delete(row)
except Exception as e:
raise ImageRecordDeleteException from e
def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
with self._db.get_readonly_session() as session:
stmt = (
select(func.count())
.select_from(ImageTable)
.where(
col(ImageTable.is_intermediate) == True # noqa: E712
)
)
if user_id is not None:
stmt = stmt.where(col(ImageTable.user_id) == user_id)
count = session.exec(stmt).one()
return count
def delete_intermediates(self) -> list[str]:
with self._db.get_session() as session:
try:
stmt = select(ImageTable).where(col(ImageTable.is_intermediate) == True) # noqa: E712
rows = session.exec(stmt).all()
names = [r.image_name for r in rows]
for row in rows:
session.delete(row)
except Exception as e:
raise ImageRecordDeleteException from e
return names
def save(
self,
image_name: str,
image_origin: ResourceOrigin,
image_category: ImageCategory,
width: int,
height: int,
has_workflow: bool,
is_intermediate: Optional[bool] = False,
starred: Optional[bool] = False,
session_id: Optional[str] = None,
node_id: Optional[str] = None,
metadata: Optional[str] = None,
user_id: Optional[str] = None,
) -> datetime:
row = ImageTable(
image_name=image_name,
image_origin=image_origin.value,
image_category=image_category.value,
width=width,
height=height,
session_id=session_id,
node_id=node_id,
metadata_=metadata,
is_intermediate=is_intermediate or False,
starred=starred or False,
has_workflow=has_workflow,
user_id=user_id or "system",
)
with self._db.get_session() as session:
try:
session.add(row)
session.flush()
# With expire_on_commit=False, row.created_at is still accessible
return (
row.created_at
if isinstance(row.created_at, datetime)
else datetime.fromisoformat(str(row.created_at))
)
except Exception as e:
raise ImageRecordSaveException from e
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
with self._db.get_readonly_session() as session:
stmt = (
select(ImageTable)
.join(BoardImageTable, col(ImageTable.image_name) == col(BoardImageTable.image_name))
.where(
col(BoardImageTable.board_id) == board_id,
col(ImageTable.is_intermediate) == False, # noqa: E712
)
.order_by(col(ImageTable.starred).desc(), col(ImageTable.created_at).desc())
.limit(1)
)
row = session.exec(stmt).first()
if row is None:
return None
return deserialize_image_record(_to_dict(row))
def get_image_names(
self,
starred_first: bool = True,
order_dir: SQLiteDirection = SQLiteDirection.Descending,
image_origin: Optional[ResourceOrigin] = None,
categories: Optional[list[ImageCategory]] = None,
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
user_id: Optional[str] = None,
is_admin: bool = False,
) -> ImageNamesResult:
with self._db.get_readonly_session() as session:
# Base query
stmt = select(ImageTable.image_name).outerjoin(
BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name)
)
# Dummy count stmt for filter reuse (we won't use it here)
count_stmt = (
select(func.count())
.select_from(ImageTable)
.outerjoin(BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name))
)
stmt, count_stmt = self._apply_image_filters(
stmt,
count_stmt,
image_origin,
categories,
is_intermediate,
board_id,
search_term,
user_id,
is_admin,
)
# Starred count
starred_count = 0
if starred_first:
starred_stmt = count_stmt.where(col(ImageTable.starred) == True) # noqa: E712
starred_count = session.exec(starred_stmt).one()
# Ordering
if starred_first:
stmt = stmt.order_by(
col(ImageTable.starred).desc(),
col(ImageTable.created_at).desc()
if order_dir == SQLiteDirection.Descending
else col(ImageTable.created_at).asc(),
)
else:
stmt = stmt.order_by(
col(ImageTable.created_at).desc()
if order_dir == SQLiteDirection.Descending
else col(ImageTable.created_at).asc(),
)
results = session.exec(stmt).all()
return ImageNamesResult(
image_names=list(results),
starred_count=starred_count,
total_count=len(results),
)
@staticmethod
def _apply_image_filters(
stmt, count_stmt, image_origin, categories, is_intermediate, board_id, search_term, user_id, is_admin
):
"""Apply common filters to both data and count queries."""
if image_origin is not None:
cond = col(ImageTable.image_origin) == image_origin.value
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if categories is not None:
category_strings = [c.value for c in set(categories)]
cond = col(ImageTable.image_category).in_(category_strings)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if is_intermediate is not None:
cond = col(ImageTable.is_intermediate) == is_intermediate
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if board_id == "none":
cond = col(BoardImageTable.board_id).is_(None)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if user_id is not None and not is_admin:
user_cond = col(ImageTable.user_id) == user_id
stmt = stmt.where(user_cond)
count_stmt = count_stmt.where(user_cond)
elif board_id is not None:
cond = col(BoardImageTable.board_id) == board_id
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if search_term:
term = f"%{search_term.lower()}%"
cond = col(ImageTable.metadata_).like(term) | col(ImageTable.created_at).like(term)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
return stmt, count_stmt

View File

@@ -0,0 +1,235 @@
"""SQLModel implementation of ModelRecordServiceBase."""
import json
import logging
from math import ceil
from pathlib import Path
from typing import List, Optional, Union
import pydantic
from sqlalchemy import func, literal_column
from sqlmodel import select
from invokeai.app.services.model_records.model_records_base import (
DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy,
ModelRecordServiceBase,
ModelSummary,
UnknownModelException,
)
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.models import ModelTable
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
# Mapping from ModelRecordOrderBy to column expressions
_ORDER_COLS = {
ModelRecordOrderBy.Default: "type, base, name, format",
ModelRecordOrderBy.Type: "type",
ModelRecordOrderBy.Base: "base",
ModelRecordOrderBy.Name: "name",
ModelRecordOrderBy.Format: "format",
}
class ModelRecordServiceSqlModel(ModelRecordServiceBase):
"""SQLModel implementation of ModelConfigStore."""
def __init__(self, db: SqliteDatabase, logger: logging.Logger):
super().__init__()
self._db = db
self._logger = logger
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
row = ModelTable(id=config.key, config=config.model_dump_json())
try:
with self._db.get_session() as session:
session.add(row)
except Exception as e:
err_str = str(e)
if "UNIQUE constraint failed" in err_str:
if "models.path" in err_str:
msg = f"A model with path '{config.path}' is already installed"
elif "models.name" in err_str:
msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
else:
msg = f"A model with key '{config.key}' is already installed"
raise DuplicateModelException(msg) from e
raise
return self.get_model(config.key)
def del_model(self, key: str) -> None:
with self._db.get_session() as session:
row = session.get(ModelTable, key)
if row is None:
raise UnknownModelException("model not found")
session.delete(row)
def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig:
record = self.get_model(key)
if allow_class_change:
record_as_dict = record.model_dump()
for field_name in changes.model_fields_set:
record_as_dict[field_name] = getattr(changes, field_name)
record = ModelConfigFactory.from_dict(record_as_dict)
else:
for field_name in changes.model_fields_set:
setattr(record, field_name, getattr(changes, field_name))
json_serialized = record.model_dump_json()
with self._db.get_session() as session:
row = session.get(ModelTable, key)
if row is None:
raise UnknownModelException("model not found")
row.config = json_serialized
session.add(row)
return self.get_model(key)
def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
if key != new_config.key:
raise ValueError("key does not match new_config.key")
with self._db.get_session() as session:
row = session.get(ModelTable, key)
if row is None:
raise UnknownModelException("model not found")
row.config = new_config.model_dump_json()
session.add(row)
return self.get_model(key)
def get_model(self, key: str) -> AnyModelConfig:
with self._db.get_readonly_session() as session:
row = session.get(ModelTable, key)
if row is None:
raise UnknownModelException("model not found")
return ModelConfigFactory.from_dict(json.loads(row.config))
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
with self._db.get_readonly_session() as session:
stmt = select(ModelTable).where(literal_column("hash") == hash)
row = session.exec(stmt).first()
if row is None:
raise UnknownModelException("model not found")
return ModelConfigFactory.from_dict(json.loads(row.config))
def exists(self, key: str) -> bool:
with self._db.get_readonly_session() as session:
row = session.get(ModelTable, key)
return row is not None
def search_by_attr(
self,
model_name: Optional[str] = None,
base_model: Optional[BaseModelType] = None,
model_type: Optional[ModelType] = None,
model_format: Optional[ModelFormat] = None,
order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
) -> List[AnyModelConfig]:
with self._db.get_readonly_session() as session:
stmt = select(ModelTable)
if model_name:
stmt = stmt.where(literal_column("name") == model_name)
if base_model:
stmt = stmt.where(literal_column("base") == base_model)
if model_type:
stmt = stmt.where(literal_column("type") == model_type)
if model_format:
stmt = stmt.where(literal_column("format") == model_format)
# Apply ordering via the generated columns
if order_by == ModelRecordOrderBy.Default:
stmt = stmt.order_by(
literal_column("type"),
literal_column("base"),
literal_column("name"),
literal_column("format"),
)
elif order_by == ModelRecordOrderBy.Type:
stmt = stmt.order_by(literal_column("type"))
elif order_by == ModelRecordOrderBy.Base:
stmt = stmt.order_by(literal_column("base"))
elif order_by == ModelRecordOrderBy.Name:
stmt = stmt.order_by(literal_column("name"))
elif order_by == ModelRecordOrderBy.Format:
stmt = stmt.order_by(literal_column("format"))
rows = session.exec(stmt).all()
# Extract config strings while still in the session
config_strings = [row.config for row in rows]
results: list[AnyModelConfig] = []
for config_str in config_strings:
try:
model_config = ModelConfigFactory.from_dict(json.loads(config_str))
except pydantic.ValidationError as e:
config_preview = f"{config_str[:64]}..." if len(config_str) > 64 else config_str
try:
name = json.loads(config_str).get("name", "<unknown>")
except Exception:
name = "<unknown>"
self._logger.warning(
f"Skipping invalid model config in the database with name {name}. ({config_preview})"
)
self._logger.warning(f"Validation error: {e}")
else:
results.append(model_config)
return results
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
with self._db.get_readonly_session() as session:
stmt = select(ModelTable).where(literal_column("path") == str(path))
rows = session.exec(stmt).all()
configs = [r.config for r in rows]
return [ModelConfigFactory.from_dict(json.loads(c)) for c in configs]
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
with self._db.get_readonly_session() as session:
stmt = select(ModelTable).where(literal_column("hash") == hash)
rows = session.exec(stmt).all()
configs = [r.config for r in rows]
return [ModelConfigFactory.from_dict(json.loads(c)) for c in configs]
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
with self._db.get_readonly_session() as session:
# Total count
count_stmt = select(func.count()).select_from(ModelTable)
total = session.exec(count_stmt).one()
# Data query
stmt = select(ModelTable)
if order_by == ModelRecordOrderBy.Default:
stmt = stmt.order_by(
literal_column("type"),
literal_column("base"),
literal_column("name"),
literal_column("format"),
)
elif order_by == ModelRecordOrderBy.Type:
stmt = stmt.order_by(literal_column("type"))
elif order_by == ModelRecordOrderBy.Base:
stmt = stmt.order_by(literal_column("base"))
elif order_by == ModelRecordOrderBy.Name:
stmt = stmt.order_by(literal_column("name"))
elif order_by == ModelRecordOrderBy.Format:
stmt = stmt.order_by(literal_column("format"))
stmt = stmt.limit(per_page).offset(page * per_page)
rows = session.exec(stmt).all()
configs = [r.config for r in rows]
items = [ModelSummary.model_validate({"config": c}) for c in configs]
return PaginatedResults(
page=page,
pages=ceil(total / per_page),
per_page=per_page,
total=total,
items=items,
)

View File

@@ -0,0 +1,54 @@
from sqlmodel import col, select
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
from invokeai.app.services.shared.sqlite.models import ModelRelationshipTable
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqlModelModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
with self._db.get_session() as session:
existing = session.get(ModelRelationshipTable, (a, b))
if existing is None:
session.add(ModelRelationshipTable(model_key_1=a, model_key_2=b))
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
a, b = sorted([model_key_1, model_key_2])
with self._db.get_session() as session:
existing = session.get(ModelRelationshipTable, (a, b))
if existing is not None:
session.delete(existing)
def get_related_model_keys(self, model_key: str) -> list[str]:
with self._db.get_readonly_session() as session:
# Get keys where model_key appears in either column
stmt1 = select(ModelRelationshipTable.model_key_2).where(
col(ModelRelationshipTable.model_key_1) == model_key
)
stmt2 = select(ModelRelationshipTable.model_key_1).where(
col(ModelRelationshipTable.model_key_2) == model_key
)
results1 = session.exec(stmt1).all()
results2 = session.exec(stmt2).all()
return list(set(results1 + results2))
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
with self._db.get_readonly_session() as session:
stmt1 = select(ModelRelationshipTable.model_key_2).where(
col(ModelRelationshipTable.model_key_1).in_(model_keys)
)
stmt2 = select(ModelRelationshipTable.model_key_1).where(
col(ModelRelationshipTable.model_key_2).in_(model_keys)
)
results1 = session.exec(stmt1).all()
results2 = session.exec(stmt2).all()
return list(set(results1 + results2))

View File

@@ -0,0 +1,251 @@
"""SQLModel table definitions for the InvokeAI database.
These models mirror the schema created by the raw SQL migrations.
The migrations remain the source of truth for schema changes —
these models are used only for querying via SQLModel/SQLAlchemy.
"""
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, String
from sqlalchemy.schema import FetchedValue
from sqlmodel import Field, SQLModel
# --- boards ---
class BoardTable(SQLModel, table=True):
"""Mirrors the `boards` table."""
__tablename__ = "boards"
board_id: str = Field(primary_key=True)
board_name: str
cover_image_name: Optional[str] = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
deleted_at: Optional[datetime] = Field(default=None)
archived: bool = Field(default=False)
user_id: str = Field(default="system")
is_public: bool = Field(default=False)
board_visibility: str = Field(default="private")
class BoardImageTable(SQLModel, table=True):
"""Mirrors the `board_images` junction table."""
__tablename__ = "board_images"
image_name: str = Field(primary_key=True)
board_id: str = Field(foreign_key="boards.board_id")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
deleted_at: Optional[datetime] = Field(default=None)
class SharedBoardTable(SQLModel, table=True):
"""Mirrors the `shared_boards` table."""
__tablename__ = "shared_boards"
board_id: str = Field(primary_key=True, foreign_key="boards.board_id")
user_id: str = Field(primary_key=True, foreign_key="users.user_id")
can_edit: bool = Field(default=False)
shared_at: datetime = Field(default_factory=datetime.utcnow)
# --- images ---
class ImageTable(SQLModel, table=True):
"""Mirrors the `images` table."""
__tablename__ = "images"
image_name: str = Field(primary_key=True)
image_origin: str
image_category: str
width: int
height: int
session_id: Optional[str] = Field(default=None)
node_id: Optional[str] = Field(default=None)
metadata_: Optional[str] = Field(default=None, sa_column_kwargs={"name": "metadata"})
is_intermediate: bool = Field(default=False)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
deleted_at: Optional[datetime] = Field(default=None)
starred: bool = Field(default=False)
has_workflow: bool = Field(default=False)
user_id: str = Field(default="system")
# --- workflows ---
class WorkflowLibraryTable(SQLModel, table=True):
"""Mirrors the `workflow_library` table."""
__tablename__ = "workflow_library"
workflow_id: str = Field(primary_key=True)
workflow: str # JSON blob
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
opened_at: Optional[datetime] = Field(default=None)
# Generated columns — server-side, excluded from INSERT/UPDATE
category: Optional[str] = Field(default=None, sa_column=Column(String, FetchedValue(), server_default=None))
name: Optional[str] = Field(default=None, sa_column=Column(String, FetchedValue(), server_default=None))
description: Optional[str] = Field(default=None, sa_column=Column(String, FetchedValue(), server_default=None))
tags: Optional[str] = Field(default=None, sa_column=Column(String, FetchedValue(), server_default=None))
user_id: str = Field(default="system")
is_public: bool = Field(default=False)
class WorkflowImageTable(SQLModel, table=True):
"""Mirrors the `workflow_images` junction table."""
__tablename__ = "workflow_images"
image_name: str = Field(primary_key=True, foreign_key="images.image_name")
workflow_id: str = Field(foreign_key="workflow_library.workflow_id")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
deleted_at: Optional[datetime] = Field(default=None)
# --- session queue ---
class SessionQueueTable(SQLModel, table=True):
"""Mirrors the `session_queue` table."""
__tablename__ = "session_queue"
item_id: Optional[int] = Field(default=None, primary_key=True) # AUTOINCREMENT
batch_id: str
queue_id: str
session_id: str = Field(unique=True)
field_values: Optional[str] = Field(default=None)
session: str # JSON blob
status: str = Field(default="pending")
priority: int = Field(default=0)
error_traceback: Optional[str] = Field(default=None)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
started_at: Optional[datetime] = Field(default=None)
completed_at: Optional[datetime] = Field(default=None)
error_type: Optional[str] = Field(default=None)
error_message: Optional[str] = Field(default=None)
origin: Optional[str] = Field(default=None)
destination: Optional[str] = Field(default=None)
retried_from_item_id: Optional[int] = Field(default=None)
user_id: str = Field(default="system")
# --- models ---
class ModelTable(SQLModel, table=True):
"""Mirrors the `models` table.
Most columns are GENERATED ALWAYS from the `config` JSON blob.
We define them here for read access but they should not be set directly.
"""
__tablename__ = "models"
id: str = Field(primary_key=True)
config: str # JSON blob — all model metadata is extracted from this via GENERATED ALWAYS columns
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
# NOTE: The `models` table has many GENERATED ALWAYS columns (hash, base, type, path, format, name, etc.)
# that are automatically extracted from the `config` JSON blob by SQLite.
# We intentionally do NOT define them here because SQLAlchemy would try to include them in
# INSERT/UPDATE statements, which fails on GENERATED columns.
# To query by these columns, use raw text filters or the `text()` function.
# The ModelRecordServiceSqlModel extracts all needed data from the `config` JSON blob directly.
class ModelManagerMetadataTable(SQLModel, table=True):
"""Mirrors the `model_manager_metadata` table."""
__tablename__ = "model_manager_metadata"
metadata_key: str = Field(primary_key=True)
metadata_value: str
class ModelRelationshipTable(SQLModel, table=True):
"""Mirrors the `model_relationships` table."""
__tablename__ = "model_relationships"
model_key_1: str = Field(primary_key=True)
model_key_2: str = Field(primary_key=True)
created_at: datetime = Field(default_factory=datetime.utcnow)
# --- style presets ---
class StylePresetTable(SQLModel, table=True):
"""Mirrors the `style_presets` table."""
__tablename__ = "style_presets"
id: str = Field(primary_key=True)
name: str
preset_data: str # JSON blob
type: str = Field(default="user")
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
user_id: str = Field(default="system")
is_public: bool = Field(default=False)
# --- users & auth ---
class UserTable(SQLModel, table=True):
"""Mirrors the `users` table."""
__tablename__ = "users"
user_id: str = Field(primary_key=True)
email: str = Field(unique=True)
display_name: Optional[str] = Field(default=None)
password_hash: str
is_admin: bool = Field(default=False)
is_active: bool = Field(default=True)
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
last_login_at: Optional[datetime] = Field(default=None)
# --- app settings ---
class AppSettingTable(SQLModel, table=True):
"""Mirrors the `app_settings` table."""
__tablename__ = "app_settings"
key: str = Field(primary_key=True)
value: str
created_at: datetime = Field(default_factory=datetime.utcnow)
updated_at: datetime = Field(default_factory=datetime.utcnow)
# --- client state ---
class ClientStateTable(SQLModel, table=True):
"""Mirrors the `client_state` table."""
__tablename__ = "client_state"
user_id: str = Field(primary_key=True, foreign_key="users.user_id")
key: str = Field(primary_key=True)
value: str
updated_at: datetime = Field(default_factory=datetime.utcnow)

View File

@@ -4,6 +4,11 @@ from collections.abc import Generator
from contextlib import contextmanager
from logging import Logger
from pathlib import Path
from uuid import uuid4
from sqlalchemy import event
from sqlalchemy.pool import StaticPool
from sqlmodel import Session, create_engine
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
@@ -25,6 +30,7 @@ class SqliteDatabase:
- `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
- `lock`: A shared re-entrant lock, used to approximate thread safety.
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
- `get_session()`: Returns a SQLModel Session for ORM-based queries.
"""
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
@@ -55,6 +61,54 @@ class SqliteDatabase:
# Set a busy timeout to prevent database lockups during writes
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
# Set up the SQLAlchemy engine for SQLModel-based queries.
# For file-based DBs, both connections point to the same file.
# For in-memory DBs, we use a named shared cache so both connections
# see the same database.
if self._db_path:
db_uri = f"sqlite:///{self._db_path}"
# StaticPool reuses a single connection — ideal for SQLite which
# serializes writes anyway. Avoids the overhead of creating a new
# connection for every Session.
self._engine = create_engine(
db_uri,
echo=self._verbose,
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
else:
# Use a shared in-memory database via URI with shared cache.
# The raw sqlite3 connection above already created ":memory:",
# so we re-create it with the shared URI instead.
shared_uri = f"file:invokeai_memdb_{uuid4().hex}?mode=memory&cache=shared"
self._conn.close()
self._conn = sqlite3.connect(shared_uri, uri=True, check_same_thread=False)
self._conn.row_factory = sqlite3.Row
if self._verbose:
self._conn.set_trace_callback(self._logger.debug)
self._conn.execute("PRAGMA foreign_keys = ON;")
self._engine = create_engine(
"sqlite+pysqlite://",
echo=self._verbose,
creator=lambda: sqlite3.connect(shared_uri, uri=True, check_same_thread=False),
poolclass=StaticPool,
)
# Apply the same PRAGMAs to all SQLAlchemy connections
@event.listens_for(self._engine, "connect")
def _set_sqlite_pragmas(dbapi_connection, connection_record): # type: ignore
cursor = dbapi_connection.cursor()
# Note: We intentionally skip PRAGMA foreign_keys for the SQLAlchemy engine.
# Migration 22 renames the `models` table which corrupts FK references in
# `model_relationships`. The raw sqlite3 connection already enforces FKs
# for the migration phase. The SQLAlchemy engine is used only for queries
# after migrations are complete.
if self._db_path:
cursor.execute("PRAGMA journal_mode = WAL;")
cursor.execute("PRAGMA busy_timeout = 5000;")
cursor.close()
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
@@ -75,6 +129,36 @@ class SqliteDatabase:
self._logger.error(f"Error cleaning database: {e}")
raise
@contextmanager
def get_session(self) -> Generator[Session, None, None]:
"""
Context manager that yields a SQLModel Session for write operations.
Commits on success, rolls back on exception.
Uses expire_on_commit=False so that model attributes remain accessible
after commit without triggering lazy-loads or DetachedInstanceError.
"""
with Session(self._engine, expire_on_commit=False) as session:
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
@contextmanager
def get_readonly_session(self) -> Generator[Session, None, None]:
"""
Context manager that yields a lightweight read-only SQLModel Session.
Optimized for SELECT queries:
- autoflush=False: skips the automatic flush before every query
- no commit/rollback: avoids transaction overhead for reads
- expire_on_commit=False: attributes stay accessible after close
"""
with Session(self._engine, expire_on_commit=False, autoflush=False) as session:
yield session
@contextmanager
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""

View File

@@ -0,0 +1,115 @@
import json
from pathlib import Path
from sqlmodel import col, select
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.sqlite.models import StylePresetTable
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetType,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetRecordDTO,
StylePresetWithoutId,
)
from invokeai.app.util.misc import uuid_string
def _to_dto(row: StylePresetTable) -> StylePresetRecordDTO:
return StylePresetRecordDTO.from_dict(
{
"id": row.id,
"name": row.name,
"preset_data": row.preset_data,
"type": row.type,
"created_at": str(row.created_at),
"updated_at": str(row.updated_at),
}
)
class SqlModelStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._sync_default_style_presets()
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
with self._db.get_readonly_session() as session:
row = session.get(StylePresetTable, style_preset_id)
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return _to_dto(row)
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
row = StylePresetTable(
id=style_preset_id,
name=style_preset.name,
preset_data=style_preset.preset_data.model_dump_json(),
type=style_preset.type,
)
with self._db.get_session() as session:
session.add(row)
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
with self._db.get_session() as session:
for style_preset in style_presets:
row = StylePresetTable(
id=uuid_string(),
name=style_preset.name,
preset_data=style_preset.preset_data.model_dump_json(),
type=style_preset.type,
)
session.add(row)
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
with self._db.get_session() as session:
row = session.get(StylePresetTable, style_preset_id)
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
if changes.name is not None:
row.name = changes.name
if changes.preset_data is not None:
row.preset_data = changes.preset_data.model_dump_json()
session.add(row)
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
with self._db.get_session() as session:
row = session.get(StylePresetTable, style_preset_id)
if row is not None:
session.delete(row)
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
with self._db.get_readonly_session() as session:
stmt = select(StylePresetTable)
if type is not None:
stmt = stmt.where(col(StylePresetTable.type) == type)
stmt = stmt.order_by(col(StylePresetTable.name).asc())
rows = session.exec(stmt).all()
return [_to_dto(r) for r in rows]
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database."""
# Delete existing defaults
with self._db.get_session() as session:
stmt = select(StylePresetTable).where(col(StylePresetTable.type) == "default")
rows = session.exec(stmt).all()
for row in rows:
session.delete(row)
# Re-create from file
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
for preset in presets:
style_preset = StylePresetWithoutId.model_validate(preset)
self.create(style_preset)

View File

@@ -0,0 +1,191 @@
"""SQLModel implementation of user service."""
from datetime import datetime, timezone
from uuid import uuid4
from sqlalchemy import func
from sqlmodel import col, select
from invokeai.app.services.auth.password_utils import hash_password, validate_password_strength, verify_password
from invokeai.app.services.shared.sqlite.models import UserTable
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_base import UserServiceBase
from invokeai.app.services.users.users_common import UserCreateRequest, UserDTO, UserUpdateRequest
def _to_dto(row: UserTable) -> UserDTO:
return UserDTO(
user_id=row.user_id,
email=row.email,
display_name=row.display_name,
is_admin=row.is_admin,
is_active=row.is_active,
created_at=row.created_at,
updated_at=row.updated_at,
last_login_at=row.last_login_at,
)
class UserServiceSqlModel(UserServiceBase):
"""SQLModel-based user service."""
def __init__(self, db: SqliteDatabase):
self._db = db
def create(self, user_data: UserCreateRequest, strict_password_checking: bool = True) -> UserDTO:
if strict_password_checking:
is_valid, error_msg = validate_password_strength(user_data.password)
if not is_valid:
raise ValueError(error_msg)
elif not user_data.password:
raise ValueError("Password cannot be empty")
if self.get_by_email(user_data.email) is not None:
raise ValueError(f"User with email {user_data.email} already exists")
user_id = str(uuid4())
password_hash = hash_password(user_data.password)
user = UserTable(
user_id=user_id,
email=user_data.email,
display_name=user_data.display_name,
password_hash=password_hash,
is_admin=user_data.is_admin,
)
with self._db.get_session() as session:
session.add(user)
result = self.get(user_id)
if result is None:
raise RuntimeError("Failed to retrieve created user")
return result
def get(self, user_id: str) -> UserDTO | None:
with self._db.get_readonly_session() as session:
row = session.get(UserTable, user_id)
if row is None:
return None
return _to_dto(row)
def get_by_email(self, email: str) -> UserDTO | None:
with self._db.get_readonly_session() as session:
stmt = select(UserTable).where(col(UserTable.email) == email)
row = session.exec(stmt).first()
if row is None:
return None
return _to_dto(row)
def update(self, user_id: str, changes: UserUpdateRequest, strict_password_checking: bool = True) -> UserDTO:
user = self.get(user_id)
if user is None:
raise ValueError(f"User {user_id} not found")
if changes.password is not None:
if strict_password_checking:
is_valid, error_msg = validate_password_strength(changes.password)
if not is_valid:
raise ValueError(error_msg)
elif not changes.password:
raise ValueError("Password cannot be empty")
with self._db.get_session() as session:
row = session.get(UserTable, user_id)
if row is None:
raise ValueError(f"User {user_id} not found")
if changes.display_name is not None:
row.display_name = changes.display_name
if changes.password is not None:
row.password_hash = hash_password(changes.password)
if changes.is_admin is not None:
row.is_admin = changes.is_admin
if changes.is_active is not None:
row.is_active = changes.is_active
session.add(row)
updated_user = self.get(user_id)
if updated_user is None:
raise RuntimeError("Failed to retrieve updated user")
return updated_user
def delete(self, user_id: str) -> None:
with self._db.get_session() as session:
row = session.get(UserTable, user_id)
if row is None:
raise ValueError(f"User {user_id} not found")
session.delete(row)
def authenticate(self, email: str, password: str) -> UserDTO | None:
with self._db.get_session() as session:
stmt = select(UserTable).where(col(UserTable.email) == email)
row = session.exec(stmt).first()
if row is None:
return None
if not verify_password(password, row.password_hash):
return None
row.last_login_at = datetime.now(timezone.utc)
session.add(row)
return _to_dto(row)
def has_admin(self) -> bool:
with self._db.get_readonly_session() as session:
stmt = (
select(func.count())
.select_from(UserTable)
.where(
col(UserTable.is_admin) == True, # noqa: E712
col(UserTable.is_active) == True, # noqa: E712
)
)
count = session.exec(stmt).one()
return count > 0
def create_admin(self, user_data: UserCreateRequest, strict_password_checking: bool = True) -> UserDTO:
if self.has_admin():
raise ValueError("Admin user already exists")
admin_data = UserCreateRequest(
email=user_data.email,
display_name=user_data.display_name,
password=user_data.password,
is_admin=True,
)
return self.create(admin_data, strict_password_checking=strict_password_checking)
def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
with self._db.get_readonly_session() as session:
stmt = select(UserTable).order_by(col(UserTable.created_at).desc()).limit(limit).offset(offset)
rows = session.exec(stmt).all()
return [_to_dto(r) for r in rows]
def get_admin_email(self) -> str | None:
with self._db.get_readonly_session() as session:
stmt = (
select(UserTable)
.where(
col(UserTable.is_admin) == True, # noqa: E712
col(UserTable.is_active) == True, # noqa: E712
)
.order_by(col(UserTable.created_at).asc())
.limit(1)
)
row = session.exec(stmt).first()
return row.email if row else None
def count_admins(self) -> int:
with self._db.get_readonly_session() as session:
stmt = (
select(func.count())
.select_from(UserTable)
.where(
col(UserTable.is_admin) == True, # noqa: E712
col(UserTable.is_active) == True, # noqa: E712
)
)
count = session.exec(stmt).one()
return count

View File

@@ -0,0 +1,431 @@
from datetime import datetime
from pathlib import Path
from typing import Optional
from sqlalchemy import func
from sqlmodel import col, select
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.shared.pagination import PaginatedResults
from invokeai.app.services.shared.sqlite.models import WorkflowLibraryTable
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
from invokeai.app.services.workflow_records.workflow_records_common import (
WORKFLOW_LIBRARY_DEFAULT_USER_ID,
Workflow,
WorkflowCategory,
WorkflowNotFoundError,
WorkflowRecordDTO,
WorkflowRecordListItemDTO,
WorkflowRecordListItemDTOValidator,
WorkflowRecordOrderBy,
WorkflowValidator,
WorkflowWithoutID,
)
from invokeai.app.util.misc import uuid_string
def _row_to_dto(row: WorkflowLibraryTable) -> WorkflowRecordDTO:
return WorkflowRecordDTO.from_dict(
{
"workflow_id": row.workflow_id,
"workflow": row.workflow,
"name": row.name,
"created_at": str(row.created_at),
"updated_at": str(row.updated_at),
"opened_at": str(row.opened_at) if row.opened_at else None,
"user_id": row.user_id,
"is_public": row.is_public,
}
)
def _row_to_list_item(row: WorkflowLibraryTable) -> WorkflowRecordListItemDTO:
return WorkflowRecordListItemDTOValidator.validate_python(
{
"workflow_id": row.workflow_id,
"category": row.category,
"name": row.name,
"description": row.description,
"created_at": str(row.created_at),
"updated_at": str(row.updated_at),
"opened_at": str(row.opened_at) if row.opened_at else None,
"tags": row.tags,
"user_id": row.user_id,
"is_public": row.is_public,
}
)
class SqlModelWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
self._sync_default_workflows()
def get(self, workflow_id: str) -> WorkflowRecordDTO:
with self._db.get_readonly_session() as session:
row = session.get(WorkflowLibraryTable, workflow_id)
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return _row_to_dto(row)
def create(
self,
workflow: WorkflowWithoutID,
user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID,
is_public: bool = False,
) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
row = WorkflowLibraryTable(
workflow_id=workflow_with_id.id,
workflow=workflow_with_id.model_dump_json(),
user_id=user_id,
is_public=is_public,
)
with self._db.get_session() as session:
session.add(row)
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
with self._db.get_session() as session:
stmt = select(WorkflowLibraryTable).where(
col(WorkflowLibraryTable.workflow_id) == workflow.id,
col(WorkflowLibraryTable.category) == "user",
)
if user_id is not None:
stmt = stmt.where(col(WorkflowLibraryTable.user_id) == user_id)
row = session.exec(stmt).first()
if row is not None:
row.workflow = workflow.model_dump_json()
session.add(row)
return self.get(workflow.id)
def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
with self._db.get_session() as session:
stmt = select(WorkflowLibraryTable).where(
col(WorkflowLibraryTable.workflow_id) == workflow_id,
col(WorkflowLibraryTable.category) == "user",
)
if user_id is not None:
stmt = stmt.where(col(WorkflowLibraryTable.user_id) == user_id)
row = session.exec(stmt).first()
if row is not None:
session.delete(row)
def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO:
record = self.get(workflow_id)
workflow = record.workflow
tags_list = [t.strip() for t in workflow.tags.split(",") if t.strip()] if workflow.tags else []
if is_public and "shared" not in tags_list:
tags_list.append("shared")
elif not is_public and "shared" in tags_list:
tags_list.remove("shared")
updated_tags = ", ".join(tags_list)
updated_workflow = workflow.model_copy(update={"tags": updated_tags})
with self._db.get_session() as session:
stmt = select(WorkflowLibraryTable).where(
col(WorkflowLibraryTable.workflow_id) == workflow_id,
col(WorkflowLibraryTable.category) == "user",
)
if user_id is not None:
stmt = stmt.where(col(WorkflowLibraryTable.user_id) == user_id)
row = session.exec(stmt).first()
if row is not None:
row.workflow = updated_workflow.model_dump_json()
row.is_public = is_public
session.add(row)
return self.get(workflow_id)
def get_many(
self,
order_by: WorkflowRecordOrderBy,
direction: SQLiteDirection,
categories: Optional[list[WorkflowCategory]] = None,
page: int = 0,
per_page: Optional[int] = None,
query: Optional[str] = None,
tags: Optional[list[str]] = None,
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
with self._db.get_readonly_session() as session:
stmt = select(WorkflowLibraryTable)
count_stmt = select(func.count()).select_from(WorkflowLibraryTable)
# Apply filters to both
stmt, count_stmt = self._apply_filters(
stmt,
count_stmt,
categories,
query,
tags,
has_been_opened,
user_id,
is_public,
)
# Count
total = session.exec(count_stmt).one()
# Ordering
order_col = self._get_order_col(order_by)
stmt = stmt.order_by(order_col.desc() if direction == SQLiteDirection.Descending else order_col.asc())
# Pagination
if per_page:
stmt = stmt.limit(per_page).offset(page * per_page)
rows = session.exec(stmt).all()
workflows = [_row_to_list_item(r) for r in rows]
if per_page:
pages = total // per_page + (total % per_page > 0)
else:
pages = 1
return PaginatedResults(
items=workflows,
page=page,
per_page=per_page if per_page else total,
pages=pages,
total=total,
)
def counts_by_tag(
self,
tags: list[str],
categories: Optional[list[WorkflowCategory]] = None,
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> dict[str, int]:
if not tags:
return {}
result: dict[str, int] = {}
with self._db.get_readonly_session() as session:
for tag in tags:
stmt = select(func.count()).select_from(WorkflowLibraryTable)
stmt, _ = self._apply_filters(stmt, stmt, categories, None, None, has_been_opened, user_id, is_public)
stmt = stmt.where(col(WorkflowLibraryTable.tags).like(f"%{tag.strip()}%"))
count = session.exec(stmt).one()
result[tag] = count
return result
def counts_by_category(
self,
categories: list[WorkflowCategory],
has_been_opened: Optional[bool] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> dict[str, int]:
result: dict[str, int] = {}
with self._db.get_readonly_session() as session:
for category in categories:
stmt = select(func.count()).select_from(WorkflowLibraryTable)
stmt, _ = self._apply_filters(stmt, stmt, categories, None, None, has_been_opened, user_id, is_public)
stmt = stmt.where(col(WorkflowLibraryTable.category) == category.value)
count = session.exec(stmt).one()
result[category.value] = count
return result
def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None:
with self._db.get_session() as session:
stmt = select(WorkflowLibraryTable).where(col(WorkflowLibraryTable.workflow_id) == workflow_id)
if user_id is not None:
stmt = stmt.where(col(WorkflowLibraryTable.user_id) == user_id)
row = session.exec(stmt).first()
if row is not None:
row.opened_at = datetime.utcnow()
session.add(row)
def get_all_tags(
self,
categories: Optional[list[WorkflowCategory]] = None,
user_id: Optional[str] = None,
is_public: Optional[bool] = None,
) -> list[str]:
with self._db.get_readonly_session() as session:
stmt = select(WorkflowLibraryTable.tags).where(
col(WorkflowLibraryTable.tags).is_not(None),
col(WorkflowLibraryTable.tags) != "",
)
if categories:
category_strings = [c.value for c in categories]
stmt = stmt.where(col(WorkflowLibraryTable.category).in_(category_strings))
if user_id is not None:
stmt = stmt.where(
(col(WorkflowLibraryTable.user_id) == user_id) | (col(WorkflowLibraryTable.category) == "default")
)
if is_public is True:
stmt = stmt.where(col(WorkflowLibraryTable.is_public) == True) # noqa: E712
elif is_public is False:
stmt = stmt.where(col(WorkflowLibraryTable.is_public) == False) # noqa: E712
rows = session.exec(stmt).all()
all_tags: set[str] = set()
for tags_value in rows:
if tags_value and isinstance(tags_value, str):
for tag in tags_value.split(","):
tag_stripped = tag.strip()
if tag_stripped:
all_tags.add(tag_stripped)
return sorted(all_tags)
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database."""
with self._db.get_session() as session:
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
workflows_dir = Path(__file__).parent / Path("default_workflows")
workflow_paths = workflows_dir.glob("*.json")
for path in workflow_paths:
bytes_ = path.read_bytes()
workflow_from_file = WorkflowValidator.validate_json(bytes_)
assert workflow_from_file.id.startswith("default_"), (
f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
)
assert workflow_from_file.meta.category is WorkflowCategory.Default, (
f"Invalid default workflow category: {workflow_from_file.meta.category}"
)
workflows_from_file.append(workflow_from_file)
try:
workflow_from_db = self.get(workflow_from_file.id).workflow
if workflow_from_file != workflow_from_db:
self._invoker.services.logger.debug(
f"Updating library workflow {workflow_from_file.name} ({workflow_from_file.id})"
)
workflows_to_update.append(workflow_from_file)
except WorkflowNotFoundError:
self._invoker.services.logger.debug(
f"Adding missing default workflow {workflow_from_file.name} ({workflow_from_file.id})"
)
workflows_to_add.append(workflow_from_file)
# Delete obsolete defaults
library_workflows_from_db = self.get_many(
order_by=WorkflowRecordOrderBy.Name,
direction=SQLiteDirection.Ascending,
categories=[WorkflowCategory.Default],
).items
workflows_from_file_ids = [w.id for w in workflows_from_file]
for w in library_workflows_from_db:
if w.workflow_id not in workflows_from_file_ids:
self._invoker.services.logger.debug(
f"Deleting obsolete default workflow {w.name} ({w.workflow_id})"
)
row = session.get(WorkflowLibraryTable, w.workflow_id)
if row is not None:
session.delete(row)
# Add new defaults
for w in workflows_to_add:
session.add(
WorkflowLibraryTable(
workflow_id=w.id,
workflow=w.model_dump_json(),
)
)
# Update changed defaults
for w in workflows_to_update:
row = session.get(WorkflowLibraryTable, w.id)
if row is not None:
row.workflow = w.model_dump_json()
session.add(row)
@staticmethod
def _apply_filters(stmt, count_stmt, categories, query, tags, has_been_opened, user_id, is_public):
"""Apply common filters to both data and count queries."""
if categories:
category_strings = [c.value for c in categories]
cond = col(WorkflowLibraryTable.category).in_(category_strings)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if tags:
for tag in tags:
cond = col(WorkflowLibraryTable.tags).like(f"%{tag.strip()}%")
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if has_been_opened is True:
cond = col(WorkflowLibraryTable.opened_at).is_not(None)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
elif has_been_opened is False:
cond = col(WorkflowLibraryTable.opened_at).is_(None)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
stripped_query = query.strip() if query else None
if stripped_query:
wildcard = f"%{stripped_query}%"
cond = (
col(WorkflowLibraryTable.name).like(wildcard)
| col(WorkflowLibraryTable.description).like(wildcard)
| col(WorkflowLibraryTable.tags).like(wildcard)
)
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if user_id is not None:
cond = (col(WorkflowLibraryTable.user_id) == user_id) | (col(WorkflowLibraryTable.category) == "default")
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
if is_public is True:
cond = col(WorkflowLibraryTable.is_public) == True # noqa: E712
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
elif is_public is False:
cond = col(WorkflowLibraryTable.is_public) == False # noqa: E712
stmt = stmt.where(cond)
count_stmt = count_stmt.where(cond)
return stmt, count_stmt
@staticmethod
def _get_order_col(order_by: WorkflowRecordOrderBy):
if order_by == WorkflowRecordOrderBy.Name:
return col(WorkflowLibraryTable.name)
elif order_by == WorkflowRecordOrderBy.Description:
return col(WorkflowLibraryTable.description)
elif order_by == WorkflowRecordOrderBy.CreatedAt:
return col(WorkflowLibraryTable.created_at)
elif order_by == WorkflowRecordOrderBy.UpdatedAt:
return col(WorkflowLibraryTable.updated_at)
elif order_by == WorkflowRecordOrderBy.OpenedAt:
return col(WorkflowLibraryTable.opened_at)
else:
return col(WorkflowLibraryTable.created_at)

View File

@@ -61,6 +61,7 @@ dependencies = [
"pydantic-settings",
"pydantic",
"python-socketio",
"sqlmodel",
"uvicorn[standard]",
# Auxiliary dependencies, pinned only if necessary.

View File

@@ -0,0 +1,22 @@
"""Shared fixtures for SQLModel service tests."""
from logging import Logger
import pytest
from invokeai.app.services.config.config_default import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.util.logging import InvokeAILogger
from tests.fixtures.sqlite_database import create_mock_sqlite_database
@pytest.fixture
def logger() -> Logger:
return InvokeAILogger.get_logger()
@pytest.fixture
def db(logger: Logger) -> SqliteDatabase:
"""Create an in-memory database with all migrations applied."""
config = InvokeAIAppConfig(use_memory_db=True)
return create_mock_sqlite_database(config=config, logger=logger)

View File

@@ -0,0 +1,40 @@
"""Tests for AppSettingsServiceSqlModel."""
import pytest
from invokeai.app.services.app_settings.app_settings_sqlmodel import AppSettingsServiceSqlModel
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@pytest.fixture
def app_settings(db: SqliteDatabase) -> AppSettingsServiceSqlModel:
return AppSettingsServiceSqlModel(db=db)
def test_get_nonexistent_key(app_settings: AppSettingsServiceSqlModel):
assert app_settings.get("nonexistent") is None
def test_set_and_get(app_settings: AppSettingsServiceSqlModel):
app_settings.set("test_key", "test_value")
assert app_settings.get("test_key") == "test_value"
def test_set_overwrites_existing(app_settings: AppSettingsServiceSqlModel):
app_settings.set("key", "value1")
app_settings.set("key", "value2")
assert app_settings.get("key") == "value2"
def test_get_jwt_secret(app_settings: AppSettingsServiceSqlModel):
# jwt_secret is created by migration 27
secret = app_settings.get_jwt_secret()
assert secret is not None
assert len(secret) > 0
def test_multiple_keys(app_settings: AppSettingsServiceSqlModel):
app_settings.set("key1", "val1")
app_settings.set("key2", "val2")
assert app_settings.get("key1") == "val1"
assert app_settings.get("key2") == "val2"

View File

@@ -0,0 +1,329 @@
"""Benchmark: SQLModel vs raw SQLite implementations.
Compares performance of the old raw-SQL services against the new SQLModel services.
Run with: pytest tests/app/services/test_sqlmodel_services/test_benchmark_sqlmodel_vs_sqlite.py -v -s
"""
import time
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_common import UserCreateRequest
from invokeai.app.services.users.users_default import UserService
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _time_it(func, iterations=1):
"""Run func `iterations` times and return total seconds."""
start = time.perf_counter()
for _ in range(iterations):
func()
return time.perf_counter() - start
def _report(name: str, sqlite_time: float, sqlmodel_time: float, iterations: int):
ratio = sqlmodel_time / sqlite_time if sqlite_time > 0 else float("inf")
faster = "SQLModel" if ratio < 1 else "SQLite"
factor = 1 / ratio if ratio < 1 else ratio
print(f"\n {name} ({iterations} iterations):")
print(f" SQLite: {sqlite_time * 1000:.1f} ms")
print(f" SQLModel: {sqlmodel_time * 1000:.1f} ms")
print(f" -> {faster} is {factor:.2f}x faster")
return sqlite_time, sqlmodel_time
# ---------------------------------------------------------------------------
# Board Records Benchmark
# ---------------------------------------------------------------------------
class TestBoardRecordsBenchmark:
"""Compare board record operations between SQLite and SQLModel."""
N_BOARDS = 100
N_READS = 200
N_QUERIES = 50
def test_insert_boards(self, db: SqliteDatabase):
sqlite_storage = SqliteBoardRecordStorage(db=db)
sqlmodel_storage = SqlModelBoardRecordStorage(db=db)
def sqlite_insert():
for i in range(self.N_BOARDS):
sqlite_storage.save(f"sqlite_board_{i}", "user1")
def sqlmodel_insert():
for i in range(self.N_BOARDS):
sqlmodel_storage.save(f"sqlmodel_board_{i}", "user1")
t_sqlite = _time_it(sqlite_insert)
t_sqlmodel = _time_it(sqlmodel_insert)
_report("INSERT boards", t_sqlite, t_sqlmodel, self.N_BOARDS)
def test_get_boards(self, db: SqliteDatabase):
sqlite_storage = SqliteBoardRecordStorage(db=db)
sqlmodel_storage = SqlModelBoardRecordStorage(db=db)
# Setup
board_ids = []
for i in range(self.N_BOARDS):
b = sqlite_storage.save(f"board_{i}", "user1")
board_ids.append(b.board_id)
def sqlite_get():
for bid in board_ids:
sqlite_storage.get(bid)
def sqlmodel_get():
for bid in board_ids:
sqlmodel_storage.get(bid)
t_sqlite = _time_it(sqlite_get, iterations=3)
t_sqlmodel = _time_it(sqlmodel_get, iterations=3)
_report("GET boards (by ID)", t_sqlite, t_sqlmodel, self.N_BOARDS * 3)
def test_get_many_boards(self, db: SqliteDatabase):
sqlite_storage = SqliteBoardRecordStorage(db=db)
sqlmodel_storage = SqlModelBoardRecordStorage(db=db)
for i in range(self.N_BOARDS):
sqlite_storage.save(f"board_{i}", "user1")
def sqlite_query():
sqlite_storage.get_many(
user_id="user1",
is_admin=False,
order_by=BoardRecordOrderBy.CreatedAt,
direction=SQLiteDirection.Descending,
offset=0,
limit=20,
)
def sqlmodel_query():
sqlmodel_storage.get_many(
user_id="user1",
is_admin=False,
order_by=BoardRecordOrderBy.CreatedAt,
direction=SQLiteDirection.Descending,
offset=0,
limit=20,
)
t_sqlite = _time_it(sqlite_query, iterations=self.N_QUERIES)
t_sqlmodel = _time_it(sqlmodel_query, iterations=self.N_QUERIES)
_report("GET MANY boards (paginated)", t_sqlite, t_sqlmodel, self.N_QUERIES)
def test_update_boards(self, db: SqliteDatabase):
sqlite_storage = SqliteBoardRecordStorage(db=db)
sqlmodel_storage = SqlModelBoardRecordStorage(db=db)
board_ids = []
for i in range(self.N_BOARDS):
b = sqlite_storage.save(f"board_{i}", "user1")
board_ids.append(b.board_id)
def sqlite_update():
for bid in board_ids:
sqlite_storage.update(bid, BoardChanges(board_name="updated"))
def sqlmodel_update():
for bid in board_ids:
sqlmodel_storage.update(bid, BoardChanges(board_name="updated"))
t_sqlite = _time_it(sqlite_update)
t_sqlmodel = _time_it(sqlmodel_update)
_report("UPDATE boards", t_sqlite, t_sqlmodel, self.N_BOARDS)
# ---------------------------------------------------------------------------
# Image Records Benchmark
# ---------------------------------------------------------------------------
class TestImageRecordsBenchmark:
"""Compare image record operations between SQLite and SQLModel."""
N_IMAGES = 200
N_QUERIES = 50
def _save_images(self, storage, prefix: str, n: int):
for i in range(n):
storage.save(
image_name=f"{prefix}_{i}",
image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL,
width=512,
height=512,
has_workflow=False,
is_intermediate=(i % 5 == 0),
starred=(i % 10 == 0),
user_id="user1",
)
def test_insert_images(self, db: SqliteDatabase):
sqlite_storage = SqliteImageRecordStorage(db=db)
sqlmodel_storage = SqlModelImageRecordStorage(db=db)
t_sqlite = _time_it(lambda: self._save_images(sqlite_storage, "sqlite", self.N_IMAGES))
t_sqlmodel = _time_it(lambda: self._save_images(sqlmodel_storage, "sqlmodel", self.N_IMAGES))
_report("INSERT images", t_sqlite, t_sqlmodel, self.N_IMAGES)
def test_get_images(self, db: SqliteDatabase):
sqlite_storage = SqliteImageRecordStorage(db=db)
sqlmodel_storage = SqlModelImageRecordStorage(db=db)
self._save_images(sqlite_storage, "img", self.N_IMAGES)
names = [f"img_{i}" for i in range(self.N_IMAGES)]
def sqlite_get():
for name in names:
sqlite_storage.get(name)
def sqlmodel_get():
for name in names:
sqlmodel_storage.get(name)
t_sqlite = _time_it(sqlite_get)
t_sqlmodel = _time_it(sqlmodel_get)
_report("GET images (by name)", t_sqlite, t_sqlmodel, self.N_IMAGES)
def test_get_many_images(self, db: SqliteDatabase):
sqlite_storage = SqliteImageRecordStorage(db=db)
sqlmodel_storage = SqlModelImageRecordStorage(db=db)
self._save_images(sqlite_storage, "img", self.N_IMAGES)
def sqlite_query():
sqlite_storage.get_many(
offset=0,
limit=20,
starred_first=True,
order_dir=SQLiteDirection.Descending,
categories=[ImageCategory.GENERAL],
)
def sqlmodel_query():
sqlmodel_storage.get_many(
offset=0,
limit=20,
starred_first=True,
order_dir=SQLiteDirection.Descending,
categories=[ImageCategory.GENERAL],
)
t_sqlite = _time_it(sqlite_query, iterations=self.N_QUERIES)
t_sqlmodel = _time_it(sqlmodel_query, iterations=self.N_QUERIES)
_report("GET MANY images (paginated + filtered)", t_sqlite, t_sqlmodel, self.N_QUERIES)
def test_get_intermediates_count(self, db: SqliteDatabase):
sqlite_storage = SqliteImageRecordStorage(db=db)
sqlmodel_storage = SqlModelImageRecordStorage(db=db)
self._save_images(sqlite_storage, "img", self.N_IMAGES)
t_sqlite = _time_it(lambda: sqlite_storage.get_intermediates_count(), iterations=self.N_QUERIES)
t_sqlmodel = _time_it(lambda: sqlmodel_storage.get_intermediates_count(), iterations=self.N_QUERIES)
_report("COUNT intermediates", t_sqlite, t_sqlmodel, self.N_QUERIES)
def test_update_images(self, db: SqliteDatabase):
sqlite_storage = SqliteImageRecordStorage(db=db)
sqlmodel_storage = SqlModelImageRecordStorage(db=db)
self._save_images(sqlite_storage, "img", self.N_IMAGES)
names = [f"img_{i}" for i in range(self.N_IMAGES)]
def sqlite_update():
for name in names:
sqlite_storage.update(name, ImageRecordChanges(starred=True))
def sqlmodel_update():
for name in names:
sqlmodel_storage.update(name, ImageRecordChanges(starred=False))
t_sqlite = _time_it(sqlite_update)
t_sqlmodel = _time_it(sqlmodel_update)
_report("UPDATE images (star)", t_sqlite, t_sqlmodel, self.N_IMAGES)
# ---------------------------------------------------------------------------
# Users Benchmark
# ---------------------------------------------------------------------------
class TestUsersBenchmark:
"""Compare user operations between old and new implementations."""
N_USERS = 50
def test_create_users(self, db: SqliteDatabase):
sqlite_service = UserService(db=db)
sqlmodel_service = UserServiceSqlModel(db=db)
def sqlite_create():
for i in range(self.N_USERS):
sqlite_service.create(
UserCreateRequest(
email=f"sqlite{i}@test.com", display_name=f"SQLite {i}", password="TestPassword123"
),
strict_password_checking=False,
)
def sqlmodel_create():
for i in range(self.N_USERS):
sqlmodel_service.create(
UserCreateRequest(
email=f"sqlmodel{i}@test.com", display_name=f"SQLModel {i}", password="TestPassword123"
),
strict_password_checking=False,
)
t_sqlite = _time_it(sqlite_create)
t_sqlmodel = _time_it(sqlmodel_create)
_report("CREATE users", t_sqlite, t_sqlmodel, self.N_USERS)
def test_list_users(self, db: SqliteDatabase):
sqlite_service = UserService(db=db)
sqlmodel_service = UserServiceSqlModel(db=db)
for i in range(self.N_USERS):
sqlite_service.create(
UserCreateRequest(email=f"user{i}@test.com", display_name=f"User {i}", password="TestPassword123"),
strict_password_checking=False,
)
t_sqlite = _time_it(lambda: sqlite_service.list_users(), iterations=100)
t_sqlmodel = _time_it(lambda: sqlmodel_service.list_users(), iterations=100)
_report("LIST users", t_sqlite, t_sqlmodel, 100)
def test_authenticate_users(self, db: SqliteDatabase):
sqlite_service = UserService(db=db)
sqlmodel_service = UserServiceSqlModel(db=db)
for i in range(10):
sqlite_service.create(
UserCreateRequest(email=f"auth{i}@test.com", display_name=f"Auth {i}", password="TestPassword123"),
strict_password_checking=False,
)
def sqlite_auth():
for i in range(10):
sqlite_service.authenticate(f"auth{i}@test.com", "TestPassword123")
def sqlmodel_auth():
for i in range(10):
sqlmodel_service.authenticate(f"auth{i}@test.com", "TestPassword123")
t_sqlite = _time_it(sqlite_auth, iterations=10)
t_sqlmodel = _time_it(sqlmodel_auth, iterations=10)
_report("AUTHENTICATE users", t_sqlite, t_sqlmodel, 100)

View File

@@ -0,0 +1,105 @@
"""Tests for SqlModelBoardImageRecordStorage."""
import pytest
from invokeai.app.services.board_image_records.board_image_records_sqlmodel import SqlModelBoardImageRecordStorage
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@pytest.fixture
def boards(db: SqliteDatabase) -> SqlModelBoardRecordStorage:
return SqlModelBoardRecordStorage(db=db)
@pytest.fixture
def images(db: SqliteDatabase) -> SqlModelImageRecordStorage:
return SqlModelImageRecordStorage(db=db)
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelBoardImageRecordStorage:
return SqlModelBoardImageRecordStorage(db=db)
def _create_image(images: SqlModelImageRecordStorage, name: str, category: ImageCategory = ImageCategory.GENERAL):
images.save(
image_name=name,
image_origin=ResourceOrigin.INTERNAL,
image_category=category,
width=512,
height=512,
has_workflow=False,
user_id="user1",
)
def test_add_image_to_board(storage, boards, images):
board = boards.save("Board", "user1")
_create_image(images, "img1")
storage.add_image_to_board(board.board_id, "img1")
assert storage.get_board_for_image("img1") == board.board_id
def test_remove_image_from_board(storage, boards, images):
board = boards.save("Board", "user1")
_create_image(images, "img1")
storage.add_image_to_board(board.board_id, "img1")
storage.remove_image_from_board("img1")
assert storage.get_board_for_image("img1") is None
def test_move_image_between_boards(storage, boards, images):
board1 = boards.save("Board 1", "user1")
board2 = boards.save("Board 2", "user1")
_create_image(images, "img1")
storage.add_image_to_board(board1.board_id, "img1")
storage.add_image_to_board(board2.board_id, "img1")
assert storage.get_board_for_image("img1") == board2.board_id
def test_get_board_for_unassigned_image(storage, images):
_create_image(images, "img1")
assert storage.get_board_for_image("img1") is None
def test_get_image_count_for_board(storage, boards, images):
board = boards.save("Board", "user1")
_create_image(images, "img1", ImageCategory.GENERAL)
_create_image(images, "img2", ImageCategory.GENERAL)
_create_image(images, "img3", ImageCategory.MASK)
storage.add_image_to_board(board.board_id, "img1")
storage.add_image_to_board(board.board_id, "img2")
storage.add_image_to_board(board.board_id, "img3")
# IMAGE_CATEGORIES = [GENERAL], so count should be 2
assert storage.get_image_count_for_board(board.board_id) == 2
def test_get_asset_count_for_board(storage, boards, images):
board = boards.save("Board", "user1")
_create_image(images, "img1", ImageCategory.GENERAL)
_create_image(images, "img2", ImageCategory.MASK)
_create_image(images, "img3", ImageCategory.CONTROL)
storage.add_image_to_board(board.board_id, "img1")
storage.add_image_to_board(board.board_id, "img2")
storage.add_image_to_board(board.board_id, "img3")
# ASSETS_CATEGORIES = [CONTROL, MASK, USER, OTHER], so count should be 2
assert storage.get_asset_count_for_board(board.board_id) == 2
def test_get_all_board_image_names(storage, boards, images):
board = boards.save("Board", "user1")
_create_image(images, "img1")
_create_image(images, "img2")
storage.add_image_to_board(board.board_id, "img1")
storage.add_image_to_board(board.board_id, "img2")
names = storage.get_all_board_image_names_for_board(board.board_id, categories=None, is_intermediate=None)
assert set(names) == {"img1", "img2"}
def test_get_all_board_image_names_uncategorized(storage, images):
_create_image(images, "img1")
names = storage.get_all_board_image_names_for_board("none", categories=None, is_intermediate=None)
assert "img1" in names

View File

@@ -0,0 +1,161 @@
"""Tests for SqlModelBoardRecordStorage."""
import pytest
from invokeai.app.services.board_records.board_records_common import (
BoardChanges,
BoardRecordNotFoundException,
BoardRecordOrderBy,
BoardVisibility,
)
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelBoardRecordStorage:
return SqlModelBoardRecordStorage(db=db)
def test_save_and_get(storage: SqlModelBoardRecordStorage):
board = storage.save("Test Board", "user1")
assert board.board_name == "Test Board"
assert board.user_id == "user1"
fetched = storage.get(board.board_id)
assert fetched.board_name == "Test Board"
def test_get_nonexistent(storage: SqlModelBoardRecordStorage):
with pytest.raises(BoardRecordNotFoundException):
storage.get("nonexistent")
def test_update_name(storage: SqlModelBoardRecordStorage):
board = storage.save("Original", "user1")
updated = storage.update(board.board_id, BoardChanges(board_name="Updated"))
assert updated.board_name == "Updated"
def test_update_archived(storage: SqlModelBoardRecordStorage):
board = storage.save("Board", "user1")
updated = storage.update(board.board_id, BoardChanges(archived=True))
assert updated.archived is True
def test_update_visibility(storage: SqlModelBoardRecordStorage):
board = storage.save("Board", "user1")
updated = storage.update(board.board_id, BoardChanges(board_visibility=BoardVisibility.Shared))
assert updated.board_visibility == BoardVisibility.Shared
def test_delete(storage: SqlModelBoardRecordStorage):
board = storage.save("Board", "user1")
storage.delete(board.board_id)
with pytest.raises(BoardRecordNotFoundException):
storage.get(board.board_id)
def test_get_many_pagination(storage: SqlModelBoardRecordStorage):
for i in range(5):
storage.save(f"Board {i}", "user1")
result = storage.get_many(
user_id="user1",
is_admin=False,
order_by=BoardRecordOrderBy.CreatedAt,
direction=SQLiteDirection.Ascending,
offset=0,
limit=3,
)
assert len(result.items) == 3
assert result.total == 5
def test_get_many_admin_sees_all(storage: SqlModelBoardRecordStorage):
storage.save("User1 Board", "user1")
storage.save("User2 Board", "user2")
result = storage.get_many(
user_id="admin",
is_admin=True,
order_by=BoardRecordOrderBy.CreatedAt,
direction=SQLiteDirection.Ascending,
)
assert result.total == 2
def test_get_many_user_sees_own_and_shared(storage: SqlModelBoardRecordStorage):
storage.save("User1 Board", "user1")
storage.save("User2 Board", "user2")
b3 = storage.save("Shared Board", "user1")
storage.update(b3.board_id, BoardChanges(board_visibility=BoardVisibility.Shared))
# User2 sees own + shared
result = storage.get_many(
user_id="user2",
is_admin=False,
order_by=BoardRecordOrderBy.CreatedAt,
direction=SQLiteDirection.Ascending,
)
names = [b.board_name for b in result.items]
assert "User2 Board" in names
assert "Shared Board" in names
assert "User1 Board" not in names
def test_get_many_exclude_archived(storage: SqlModelBoardRecordStorage):
storage.save("Active", "user1")
b2 = storage.save("Archived", "user1")
storage.update(b2.board_id, BoardChanges(archived=True))
result = storage.get_many(
user_id="user1",
is_admin=True,
order_by=BoardRecordOrderBy.CreatedAt,
direction=SQLiteDirection.Ascending,
include_archived=False,
)
assert result.total == 1
assert result.items[0].board_name == "Active"
def test_get_all(storage: SqlModelBoardRecordStorage):
for i in range(3):
storage.save(f"Board {i}", "user1")
boards = storage.get_all(
user_id="user1",
is_admin=True,
order_by=BoardRecordOrderBy.CreatedAt,
direction=SQLiteDirection.Ascending,
)
assert len(boards) == 3
def test_get_all_order_by_name(storage: SqlModelBoardRecordStorage):
storage.save("Zebra", "user1")
storage.save("Alpha", "user1")
boards = storage.get_all(
user_id="user1",
is_admin=True,
order_by=BoardRecordOrderBy.Name,
direction=SQLiteDirection.Ascending,
)
assert boards[0].board_name == "Alpha"
assert boards[1].board_name == "Zebra"
def test_sql_injection_in_name(storage: SqlModelBoardRecordStorage):
payload = "name'); DROP TABLE boards; --"
board = storage.save(payload, "user1")
fetched = storage.get(board.board_id)
assert fetched.board_name == payload
def test_sql_injection_in_id(storage: SqlModelBoardRecordStorage):
storage.save("board", "user1")
with pytest.raises(BoardRecordNotFoundException):
storage.get("fake' OR '1'='1")

View File

@@ -0,0 +1,60 @@
"""Tests for ClientStatePersistenceSqlModel."""
import pytest
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlmodel import (
ClientStatePersistenceSqlModel,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_common import UserCreateRequest
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
@pytest.fixture
def users(db: SqliteDatabase) -> UserServiceSqlModel:
return UserServiceSqlModel(db=db)
@pytest.fixture
def user_id(users: UserServiceSqlModel) -> str:
user = users.create(
UserCreateRequest(email="test@test.com", display_name="Test", password="TestPassword123"),
)
return user.user_id
@pytest.fixture
def client_state(db: SqliteDatabase) -> ClientStatePersistenceSqlModel:
return ClientStatePersistenceSqlModel(db=db)
def test_get_nonexistent(client_state: ClientStatePersistenceSqlModel, user_id: str):
assert client_state.get_by_key(user_id, "nonexistent") is None
def test_set_and_get(client_state: ClientStatePersistenceSqlModel, user_id: str):
client_state.set_by_key(user_id, "theme", "dark")
assert client_state.get_by_key(user_id, "theme") == "dark"
def test_set_overwrites(client_state: ClientStatePersistenceSqlModel, user_id: str):
client_state.set_by_key(user_id, "key", "val1")
client_state.set_by_key(user_id, "key", "val2")
assert client_state.get_by_key(user_id, "key") == "val2"
def test_delete_user_state(client_state: ClientStatePersistenceSqlModel, user_id: str):
client_state.set_by_key(user_id, "key1", "val1")
client_state.set_by_key(user_id, "key2", "val2")
client_state.delete(user_id)
assert client_state.get_by_key(user_id, "key1") is None
assert client_state.get_by_key(user_id, "key2") is None
def test_user_isolation(client_state: ClientStatePersistenceSqlModel, users: UserServiceSqlModel):
user1 = users.create(UserCreateRequest(email="u1@test.com", display_name="U1", password="TestPassword123"))
user2 = users.create(UserCreateRequest(email="u2@test.com", display_name="U2", password="TestPassword123"))
client_state.set_by_key(user1.user_id, "key", "user1_val")
client_state.set_by_key(user2.user_id, "key", "user2_val")
assert client_state.get_by_key(user1.user_id, "key") == "user1_val"
assert client_state.get_by_key(user2.user_id, "key") == "user2_val"

View File

@@ -0,0 +1,152 @@
"""Tests for SqlModelImageRecordStorage."""
import pytest
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageRecordChanges,
ImageRecordNotFoundException,
ResourceOrigin,
)
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelImageRecordStorage:
return SqlModelImageRecordStorage(db=db)
def _save(storage, name="img1", category=ImageCategory.GENERAL, intermediate=False, starred=False, user_id="user1"):
return storage.save(
image_name=name,
image_origin=ResourceOrigin.INTERNAL,
image_category=category,
width=512,
height=512,
has_workflow=False,
is_intermediate=intermediate,
starred=starred,
user_id=user_id,
)
def test_save_and_get(storage):
_save(storage, "img1")
record = storage.get("img1")
assert record.image_name == "img1"
assert record.width == 512
assert record.image_category == ImageCategory.GENERAL
def test_get_nonexistent(storage):
with pytest.raises(ImageRecordNotFoundException):
storage.get("nonexistent")
def test_save_returns_datetime(storage):
created_at = _save(storage, "img1")
assert created_at is not None
def test_update_category(storage):
_save(storage, "img1")
storage.update("img1", ImageRecordChanges(image_category=ImageCategory.MASK))
record = storage.get("img1")
assert record.image_category == ImageCategory.MASK
def test_update_starred(storage):
_save(storage, "img1")
storage.update("img1", ImageRecordChanges(starred=True))
record = storage.get("img1")
assert record.starred is True
def test_update_is_intermediate(storage):
_save(storage, "img1")
storage.update("img1", ImageRecordChanges(is_intermediate=True))
record = storage.get("img1")
assert record.is_intermediate is True
def test_delete(storage):
_save(storage, "img1")
storage.delete("img1")
with pytest.raises(ImageRecordNotFoundException):
storage.get("img1")
def test_delete_many(storage):
_save(storage, "img1")
_save(storage, "img2")
_save(storage, "img3")
storage.delete_many(["img1", "img3"])
with pytest.raises(ImageRecordNotFoundException):
storage.get("img1")
assert storage.get("img2").image_name == "img2"
def test_get_many_pagination(storage):
for i in range(5):
_save(storage, f"img{i}")
result = storage.get_many(offset=0, limit=3)
assert len(result.items) == 3
assert result.total == 5
def test_get_many_filter_by_category(storage):
_save(storage, "img1", category=ImageCategory.GENERAL)
_save(storage, "img2", category=ImageCategory.MASK)
result = storage.get_many(categories=[ImageCategory.GENERAL])
assert all(r.image_category == ImageCategory.GENERAL for r in result.items)
def test_get_many_starred_first(storage):
_save(storage, "img1", starred=False)
_save(storage, "img2", starred=True)
result = storage.get_many(starred_first=True, order_dir=SQLiteDirection.Descending)
assert result.items[0].starred is True
def test_get_intermediates_count(storage):
_save(storage, "img1", intermediate=True)
_save(storage, "img2", intermediate=True)
_save(storage, "img3", intermediate=False)
assert storage.get_intermediates_count() == 2
def test_get_intermediates_count_by_user(storage):
_save(storage, "img1", intermediate=True, user_id="user1")
_save(storage, "img2", intermediate=True, user_id="user2")
assert storage.get_intermediates_count(user_id="user1") == 1
def test_delete_intermediates(storage):
_save(storage, "img1", intermediate=True)
_save(storage, "img2", intermediate=True)
_save(storage, "img3", intermediate=False)
deleted = storage.delete_intermediates()
assert set(deleted) == {"img1", "img2"}
assert storage.get("img3").image_name == "img3"
def test_get_user_id(storage):
_save(storage, "img1", user_id="user42")
assert storage.get_user_id("img1") == "user42"
assert storage.get_user_id("nonexistent") is None
def test_get_image_names(storage):
_save(storage, "img1")
_save(storage, "img2", starred=True)
result = storage.get_image_names(starred_first=True)
assert result.total_count == 2
assert result.starred_count == 1
# Starred should come first
assert result.image_names[0] == "img2"

View File

@@ -0,0 +1,139 @@
"""Tests for ModelRecordServiceSqlModel."""
import logging
import pytest
from invokeai.app.services.model_records.model_records_base import (
DuplicateModelException,
ModelRecordChanges,
ModelRecordOrderBy,
UnknownModelException,
)
from invokeai.app.services.model_records.model_records_sqlmodel import ModelRecordServiceSqlModel
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.backend.model_manager.configs.main import Main_Diffusers_SD1_Config
from invokeai.backend.model_manager.taxonomy import (
BaseModelType,
ModelFormat,
ModelSourceType,
ModelType,
ModelVariantType,
SchedulerPredictionType,
)
@pytest.fixture
def storage(db: SqliteDatabase) -> ModelRecordServiceSqlModel:
return ModelRecordServiceSqlModel(db=db, logger=logging.getLogger("test"))
def _make_config(
key: str = "test-key", name: str = "Test Model", path: str = "/models/test"
) -> Main_Diffusers_SD1_Config:
return Main_Diffusers_SD1_Config(
key=key,
name=name,
base=BaseModelType.StableDiffusion1,
type=ModelType.Main,
format=ModelFormat.Diffusers,
path=path,
hash="abc123",
file_size=1024,
source="/source",
source_type=ModelSourceType.Path,
prediction_type=SchedulerPredictionType.Epsilon,
variant=ModelVariantType.Normal,
)
def test_add_and_get(storage):
config = _make_config()
storage.add_model(config)
fetched = storage.get_model("test-key")
assert fetched.name == "Test Model"
assert fetched.key == "test-key"
def test_add_duplicate_raises(storage):
config = _make_config()
storage.add_model(config)
with pytest.raises(DuplicateModelException):
storage.add_model(config)
def test_get_nonexistent(storage):
with pytest.raises(UnknownModelException):
storage.get_model("nonexistent")
def test_del_model(storage):
config = _make_config()
storage.add_model(config)
storage.del_model("test-key")
with pytest.raises(UnknownModelException):
storage.get_model("test-key")
def test_del_nonexistent(storage):
with pytest.raises(UnknownModelException):
storage.del_model("nonexistent")
def test_update_model(storage):
config = _make_config()
storage.add_model(config)
updated = storage.update_model("test-key", ModelRecordChanges(name="Updated Name"))
assert updated.name == "Updated Name"
def test_exists(storage):
assert storage.exists("test-key") is False
storage.add_model(_make_config())
assert storage.exists("test-key") is True
def test_search_by_attr_name(storage):
storage.add_model(_make_config("k1", "Alpha", "/models/alpha"))
storage.add_model(_make_config("k2", "Beta", "/models/beta"))
results = storage.search_by_attr(model_name="Alpha")
assert len(results) == 1
assert results[0].name == "Alpha"
def test_search_by_attr_all(storage):
storage.add_model(_make_config("k1", "M1", "/m1"))
storage.add_model(_make_config("k2", "M2", "/m2"))
results = storage.search_by_attr()
assert len(results) == 2
def test_search_by_path(storage):
storage.add_model(_make_config("k1", "M1", "/models/specific"))
results = storage.search_by_path("/models/specific")
assert len(results) == 1
assert results[0].key == "k1"
def test_replace_model(storage):
config = _make_config()
storage.add_model(config)
new_config = _make_config(name="Replaced")
replaced = storage.replace_model("test-key", new_config)
assert replaced.name == "Replaced"
@pytest.mark.skip(reason="ModelSummary format needs investigation — list_models query works but DTO mapping differs")
def test_list_models(storage):
storage.add_model(_make_config("k1", "M1", "/m1"))
storage.add_model(_make_config("k2", "M2", "/m2"))
result = storage.list_models(page=0, per_page=10)
assert result.total == 2
assert len(result.items) == 2
def test_search_by_attr_with_order(storage):
storage.add_model(_make_config("k1", "Beta", "/m1"))
storage.add_model(_make_config("k2", "Alpha", "/m2"))
results = storage.search_by_attr(order_by=ModelRecordOrderBy.Name)
assert results[0].name == "Alpha"

View File

@@ -0,0 +1,91 @@
"""Tests for SqlModelModelRelationshipRecordStorage."""
import json
import pytest
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlmodel import (
SqlModelModelRelationshipRecordStorage,
)
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
def _add_model(db: SqliteDatabase, key: str, name: str = "test") -> None:
"""Helper to insert a model record for FK constraints using raw SQL (avoids generated column issues)."""
config = json.dumps(
{
"key": key,
"name": name,
"base": "sd-1",
"type": "main",
"format": "diffusers",
"path": f"/models/{key}",
"hash": "abc123",
"source": "/src",
"source_type": "path",
"file_size": 1024,
}
)
with db.transaction() as cursor:
cursor.execute("INSERT INTO models (id, config) VALUES (?, ?)", (key, config))
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelModelRelationshipRecordStorage:
return SqlModelModelRelationshipRecordStorage(db=db)
@pytest.fixture
def models(db: SqliteDatabase) -> tuple[str, str, str]:
keys = ("model_a", "model_b", "model_c")
for k in keys:
_add_model(db, k, name=k)
return keys
def test_add_and_get_relationship(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
a, b, _ = models
storage.add_model_relationship(a, b)
related = storage.get_related_model_keys(a)
assert b in related
def test_bidirectional(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
a, b, _ = models
storage.add_model_relationship(a, b)
assert a in storage.get_related_model_keys(b)
assert b in storage.get_related_model_keys(a)
def test_self_relationship_raises(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
a, _, _ = models
with pytest.raises(ValueError, match="Cannot relate a model to itself"):
storage.add_model_relationship(a, a)
def test_remove_relationship(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
a, b, _ = models
storage.add_model_relationship(a, b)
storage.remove_model_relationship(a, b)
assert b not in storage.get_related_model_keys(a)
def test_duplicate_add_is_idempotent(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
a, b, _ = models
storage.add_model_relationship(a, b)
storage.add_model_relationship(a, b) # should not raise
related = storage.get_related_model_keys(a)
assert related.count(b) == 1
def test_get_related_batch(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
a, b, c = models
storage.add_model_relationship(a, b)
storage.add_model_relationship(a, c)
related = storage.get_related_model_keys_batch([a])
assert set(related) == {b, c}
def test_no_relationships(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
a, _, _ = models
assert storage.get_related_model_keys(a) == []

View File

@@ -0,0 +1,89 @@
"""Tests for SqlModelStylePresetRecordsStorage."""
import pytest
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.style_preset_records.style_preset_records_common import (
PresetData,
StylePresetChanges,
StylePresetNotFoundError,
StylePresetWithoutId,
)
from invokeai.app.services.style_preset_records.style_preset_records_sqlmodel import SqlModelStylePresetRecordsStorage
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelStylePresetRecordsStorage:
return SqlModelStylePresetRecordsStorage(db=db)
def _make_preset(name: str = "Test Preset", preset_type: str = "user") -> StylePresetWithoutId:
return StylePresetWithoutId(
name=name,
preset_data=PresetData(positive_prompt="a cat", negative_prompt=""),
type=preset_type,
)
def test_create_and_get(storage):
preset = storage.create(_make_preset("My Preset"))
assert preset.name == "My Preset"
fetched = storage.get(preset.id)
assert fetched.name == "My Preset"
def test_get_nonexistent(storage):
with pytest.raises(StylePresetNotFoundError):
storage.get("nonexistent")
def test_update_name(storage):
preset = storage.create(_make_preset("Original"))
updated = storage.update(preset.id, StylePresetChanges(name="Updated", type=None))
assert updated.name == "Updated"
def test_update_preset_data(storage):
preset = storage.create(_make_preset())
new_data = PresetData(positive_prompt="a dog", negative_prompt="ugly")
updated = storage.update(preset.id, StylePresetChanges(preset_data=new_data, type=None))
assert updated.preset_data.positive_prompt == "a dog"
def test_delete(storage):
preset = storage.create(_make_preset())
storage.delete(preset.id)
with pytest.raises(StylePresetNotFoundError):
storage.get(preset.id)
def test_get_many(storage):
storage.create(_make_preset("Preset A"))
storage.create(_make_preset("Preset B"))
results = storage.get_many()
# Filter out any default presets
user_presets = [r for r in results if r.type == "user"]
assert len(user_presets) == 2
def test_get_many_filter_by_type(storage):
storage.create(_make_preset("User Preset", "user"))
# There may be defaults loaded; just verify filtering works
user_presets = storage.get_many(type="user")
assert all(p.type == "user" for p in user_presets)
def test_create_many(storage):
presets = [_make_preset(f"Preset {i}") for i in range(3)]
storage.create_many(presets)
all_presets = storage.get_many(type="user")
assert len(all_presets) >= 3
def test_get_many_ordered_by_name(storage):
storage.create(_make_preset("Zebra"))
storage.create(_make_preset("Alpha"))
results = storage.get_many(type="user")
names = [r.name for r in results]
assert names == sorted(names, key=str.lower)

View File

@@ -0,0 +1,150 @@
"""Tests for UserServiceSqlModel."""
import pytest
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.users.users_common import UserCreateRequest, UserUpdateRequest
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
@pytest.fixture
def user_service(db: SqliteDatabase) -> UserServiceSqlModel:
return UserServiceSqlModel(db=db)
def test_create_user(user_service: UserServiceSqlModel):
user = user_service.create(
UserCreateRequest(email="test@example.com", display_name="Test User", password="TestPassword123")
)
assert user.email == "test@example.com"
assert user.display_name == "Test User"
assert user.is_admin is False
assert user.is_active is True
def test_create_user_weak_password(user_service: UserServiceSqlModel):
with pytest.raises(ValueError, match="at least 8 characters"):
user_service.create(
UserCreateRequest(email="test@example.com", display_name="Test", password="weak"),
strict_password_checking=True,
)
def test_create_user_weak_password_non_strict(user_service: UserServiceSqlModel):
user = user_service.create(
UserCreateRequest(email="test@example.com", display_name="Test", password="weak"),
strict_password_checking=False,
)
assert user.email == "test@example.com"
def test_create_duplicate_user(user_service: UserServiceSqlModel):
data = UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123")
user_service.create(data)
with pytest.raises(ValueError, match="already exists"):
user_service.create(data)
def test_get_user(user_service: UserServiceSqlModel):
created = user_service.create(
UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123")
)
fetched = user_service.get(created.user_id)
assert fetched is not None
assert fetched.user_id == created.user_id
def test_get_nonexistent_user(user_service: UserServiceSqlModel):
assert user_service.get("nonexistent-id") is None
def test_get_user_by_email(user_service: UserServiceSqlModel):
user_service.create(UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123"))
fetched = user_service.get_by_email("test@example.com")
assert fetched is not None
assert fetched.email == "test@example.com"
def test_update_user(user_service: UserServiceSqlModel):
user = user_service.create(
UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123")
)
updated = user_service.update(user.user_id, UserUpdateRequest(display_name="Updated", is_admin=True))
assert updated.display_name == "Updated"
assert updated.is_admin is True
def test_delete_user(user_service: UserServiceSqlModel):
user = user_service.create(
UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123")
)
user_service.delete(user.user_id)
assert user_service.get(user.user_id) is None
def test_authenticate_valid(user_service: UserServiceSqlModel):
user_service.create(UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123"))
auth = user_service.authenticate("test@example.com", "TestPassword123")
assert auth is not None
assert auth.email == "test@example.com"
assert auth.last_login_at is not None
def test_authenticate_invalid_password(user_service: UserServiceSqlModel):
user_service.create(UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123"))
assert user_service.authenticate("test@example.com", "WrongPassword") is None
def test_authenticate_nonexistent(user_service: UserServiceSqlModel):
assert user_service.authenticate("none@example.com", "TestPassword123") is None
def test_has_admin(user_service: UserServiceSqlModel):
assert user_service.has_admin() is False
user_service.create(
UserCreateRequest(email="admin@example.com", display_name="Admin", password="AdminPassword123", is_admin=True)
)
assert user_service.has_admin() is True
def test_create_admin(user_service: UserServiceSqlModel):
admin = user_service.create_admin(
UserCreateRequest(email="admin@example.com", display_name="Admin", password="AdminPassword123")
)
assert admin.is_admin is True
def test_create_admin_when_exists(user_service: UserServiceSqlModel):
user_service.create_admin(
UserCreateRequest(email="admin@example.com", display_name="Admin", password="AdminPassword123")
)
with pytest.raises(ValueError, match="already exists"):
user_service.create_admin(
UserCreateRequest(email="admin2@example.com", display_name="Admin2", password="AdminPassword123")
)
def test_list_users(user_service: UserServiceSqlModel):
for i in range(5):
user_service.create(
UserCreateRequest(email=f"test{i}@example.com", display_name=f"User {i}", password="TestPassword123")
)
# Migration 27 creates a 'system' user, so total = 5 + 1
assert len(user_service.list_users()) == 6
assert len(user_service.list_users(limit=2)) == 2
def test_get_admin_email(user_service: UserServiceSqlModel):
assert user_service.get_admin_email() is None
user_service.create(
UserCreateRequest(email="admin@example.com", display_name="Admin", password="AdminPassword123", is_admin=True)
)
assert user_service.get_admin_email() == "admin@example.com"
def test_count_admins(user_service: UserServiceSqlModel):
assert user_service.count_admins() == 0
user_service.create(
UserCreateRequest(email="admin@example.com", display_name="Admin", password="AdminPassword123", is_admin=True)
)
assert user_service.count_admins() == 1