mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Compare commits
7 Commits
main
...
feature/sq
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ca10b10b2e | ||
|
|
80120d3312 | ||
|
|
5d3e30eb67 | ||
|
|
0a428ffff4 | ||
|
|
4f7343b4e4 | ||
|
|
aeb6643879 | ||
|
|
1a06fe0d8d |
@@ -5,14 +5,16 @@ from logging import Logger
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.app.services.app_settings import AppSettingsService
|
||||
from invokeai.app.services.app_settings.app_settings_sqlmodel import AppSettingsServiceSqlModel
|
||||
from invokeai.app.services.auth.token_service import set_jwt_secret
|
||||
from invokeai.app.services.board_image_records.board_image_records_sqlite import SqliteBoardImageRecordStorage
|
||||
from invokeai.app.services.board_image_records.board_image_records_sqlmodel import SqlModelBoardImageRecordStorage
|
||||
from invokeai.app.services.board_images.board_images_default import BoardImagesService
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
|
||||
from invokeai.app.services.boards.boards_default import BoardService
|
||||
from invokeai.app.services.bulk_download.bulk_download_default import BulkDownloadService
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlite import ClientStatePersistenceSqlite
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlmodel import (
|
||||
ClientStatePersistenceSqlModel,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.download.download_default import DownloadQueueService
|
||||
from invokeai.app.services.events.events_fastapievents import FastAPIEventService
|
||||
@@ -20,7 +22,7 @@ from invokeai.app.services.external_generation.external_generation_default impor
|
||||
from invokeai.app.services.external_generation.providers import GeminiProvider, OpenAIProvider
|
||||
from invokeai.app.services.external_generation.startup import sync_configured_external_starter_models
|
||||
from invokeai.app.services.image_files.image_files_disk import DiskImageFileStorage
|
||||
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
|
||||
from invokeai.app.services.images.images_default import ImageService
|
||||
from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
@@ -28,9 +30,9 @@ from invokeai.app.services.invocation_stats.invocation_stats_default import Invo
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_images.model_images_default import ModelImageFileStorageDisk
|
||||
from invokeai.app.services.model_manager.model_manager_default import ModelManagerService
|
||||
from invokeai.app.services.model_records.model_records_sql import ModelRecordServiceSQL
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlite import (
|
||||
SqliteModelRelationshipRecordStorage,
|
||||
from invokeai.app.services.model_records.model_records_sqlmodel import ModelRecordServiceSqlModel
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlmodel import (
|
||||
SqlModelModelRelationshipRecordStorage,
|
||||
)
|
||||
from invokeai.app.services.model_relationships.model_relationships_default import ModelRelationshipsService
|
||||
from invokeai.app.services.names.names_default import SimpleNameService
|
||||
@@ -40,13 +42,13 @@ from invokeai.app.services.session_processor.session_processor_default import (
|
||||
DefaultSessionProcessor,
|
||||
DefaultSessionRunner,
|
||||
)
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.session_queue.session_queue_sqlmodel import SqlModelSessionQueue
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlite import SqliteStylePresetRecordsStorage
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlmodel import SqlModelStylePresetRecordsStorage
|
||||
from invokeai.app.services.urls.urls_default import LocalUrlService
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
|
||||
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
|
||||
from invokeai.app.services.workflow_records.workflow_records_sqlmodel import SqlModelWorkflowRecordsStorage
|
||||
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
AnimaConditioningInfo,
|
||||
@@ -110,7 +112,7 @@ class ApiDependencies:
|
||||
db = init_db(config=config, logger=logger, image_files=image_files)
|
||||
|
||||
# Initialize JWT secret from database
|
||||
app_settings = AppSettingsService(db=db)
|
||||
app_settings = AppSettingsServiceSqlModel(db=db)
|
||||
jwt_secret = app_settings.get_jwt_secret()
|
||||
set_jwt_secret(jwt_secret)
|
||||
logger.info("JWT secret loaded from database")
|
||||
@@ -118,13 +120,13 @@ class ApiDependencies:
|
||||
configuration = config
|
||||
logger = logger
|
||||
|
||||
board_image_records = SqliteBoardImageRecordStorage(db=db)
|
||||
board_image_records = SqlModelBoardImageRecordStorage(db=db)
|
||||
board_images = BoardImagesService()
|
||||
board_records = SqliteBoardRecordStorage(db=db)
|
||||
board_records = SqlModelBoardRecordStorage(db=db)
|
||||
boards = BoardService()
|
||||
events = FastAPIEventService(event_handler_id, loop=loop)
|
||||
bulk_download = BulkDownloadService()
|
||||
image_records = SqliteImageRecordStorage(db=db)
|
||||
image_records = SqlModelImageRecordStorage(db=db)
|
||||
images = ImageService()
|
||||
invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size)
|
||||
tensors = ObjectSerializerForwardCache(
|
||||
@@ -152,7 +154,7 @@ class ApiDependencies:
|
||||
),
|
||||
)
|
||||
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||
model_record_service = ModelRecordServiceSQL(db=db, logger=logger)
|
||||
model_record_service = ModelRecordServiceSqlModel(db=db, logger=logger)
|
||||
model_manager = ModelManagerService.build_model_manager(
|
||||
app_config=configuration,
|
||||
model_record_service=model_record_service,
|
||||
@@ -169,18 +171,18 @@ class ApiDependencies:
|
||||
)
|
||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||
model_relationships = ModelRelationshipsService()
|
||||
model_relationship_records = SqliteModelRelationshipRecordStorage(db=db)
|
||||
model_relationship_records = SqlModelModelRelationshipRecordStorage(db=db)
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
|
||||
session_queue = SqliteSessionQueue(db=db)
|
||||
session_queue = SqlModelSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_records = SqliteWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqliteStylePresetRecordsStorage(db=db)
|
||||
workflow_records = SqlModelWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqlModelStylePresetRecordsStorage(db=db)
|
||||
style_preset_image_files = StylePresetImageFileStorageDisk(style_presets_folder / "images")
|
||||
workflow_thumbnails = WorkflowThumbnailFileStorageDisk(workflow_thumbnails_folder)
|
||||
client_state_persistence = ClientStatePersistenceSqlite(db=db)
|
||||
users = UserService(db=db)
|
||||
client_state_persistence = ClientStatePersistenceSqlModel(db=db)
|
||||
users = UserServiceSqlModel(db=db)
|
||||
|
||||
services = InvocationServices(
|
||||
board_image_records=board_image_records,
|
||||
|
||||
37
invokeai/app/services/app_settings/app_settings_sqlmodel.py
Normal file
37
invokeai/app/services/app_settings/app_settings_sqlmodel.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from typing import Optional
|
||||
|
||||
from invokeai.app.services.shared.sqlite.models import AppSettingTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class AppSettingsServiceSqlModel:
    """SQLModel-backed store for application-level settings.

    Settings are plain key/value string pairs persisted in the
    ``AppSettingTable`` table.
    """

    def __init__(self, db: SqliteDatabase) -> None:
        self._db = db

    def get(self, key: str) -> Optional[str]:
        """Return the value stored under *key*, or None when absent.

        Reads are best-effort: any storage error is reported as "no value"
        rather than propagated to the caller.
        """
        try:
            with self._db.get_readonly_session() as session:
                if (row := session.get(AppSettingTable, key)) is not None:
                    return row.value
                return None
        except Exception:
            return None

    def set(self, key: str, value: str) -> None:
        """Create or overwrite the setting *key* with *value* (upsert)."""
        with self._db.get_session() as session:
            row = session.get(AppSettingTable, key)
            if row is None:
                row = AppSettingTable(key=key, value=value)
            else:
                row.value = value
            session.add(row)

    def get_jwt_secret(self) -> str:
        """Return the JWT signing secret; raise if it was never provisioned."""
        secret = self.get("jwt_secret")
        if secret is None:
            raise RuntimeError(
                "JWT secret not found in database. This should have been created during database migration. "
                "Please ensure database migrations have been run successfully."
            )
        return secret
||||
@@ -0,0 +1,145 @@
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from invokeai.app.services.board_image_records.board_image_records_base import BoardImageRecordStorageBase
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ASSETS_CATEGORIES,
|
||||
IMAGE_CATEGORIES,
|
||||
ImageCategory,
|
||||
ImageRecord,
|
||||
deserialize_image_record,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.models import BoardImageTable, ImageTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class SqlModelBoardImageRecordStorage(BoardImageRecordStorageBase):
    """Board<->image association storage using SQLModel.

    An image belongs to at most one board (``BoardImageTable`` is keyed by
    ``image_name``, per the ``session.get`` lookups below), so adding an
    already-boarded image moves it to the new board.
    """

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def add_image_to_board(self, board_id: str, image_name: str) -> None:
        """Associate *image_name* with *board_id* (upsert)."""
        with self._db.get_session() as session:
            existing = session.get(BoardImageTable, image_name)
            if existing is not None:
                existing.board_id = board_id
                session.add(existing)
            else:
                session.add(BoardImageTable(board_id=board_id, image_name=image_name))

    def remove_image_from_board(self, image_name: str) -> None:
        """Drop the association for *image_name*; no-op when not on a board."""
        with self._db.get_session() as session:
            existing = session.get(BoardImageTable, image_name)
            if existing is not None:
                session.delete(existing)

    def get_images_for_board(
        self,
        board_id: str,
        offset: int = 0,
        limit: int = 10,
    ) -> OffsetPaginatedResults[ImageRecord]:
        """Return one page of the board's images, newest association first.

        FIX: the previous implementation ignored ``offset``/``limit`` when
        querying (every row was fetched) and reported ``total`` as the count
        of ALL images in the database; both now reflect the requested board.
        """
        with self._db.get_readonly_session() as session:
            stmt = (
                select(ImageTable)
                .join(BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name))
                .where(col(BoardImageTable.board_id) == board_id)
                .order_by(col(BoardImageTable.updated_at).desc())
                .offset(offset)
                .limit(limit)
            )
            results = session.exec(stmt).all()
            images = [deserialize_image_record(_image_to_dict(r)) for r in results]

            # Total count of this board's images (not the whole images table).
            count_stmt = (
                select(func.count())
                .select_from(BoardImageTable)
                .where(col(BoardImageTable.board_id) == board_id)
            )
            count = session.exec(count_stmt).one()

            return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)

    def get_all_board_image_names_for_board(
        self,
        board_id: str,
        categories: list[ImageCategory] | None,
        is_intermediate: bool | None,
    ) -> list[str]:
        """Return names of the board's images, with optional category and
        intermediate filters; board_id "none" selects boardless images."""
        with self._db.get_readonly_session() as session:
            stmt = select(ImageTable.image_name).outerjoin(
                BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name)
            )

            if board_id == "none":
                # Sentinel board id: images with no board association at all.
                stmt = stmt.where(col(BoardImageTable.board_id).is_(None))
            else:
                stmt = stmt.where(col(BoardImageTable.board_id) == board_id)

            if categories is not None:
                category_strings = [c.value for c in set(categories)]
                stmt = stmt.where(col(ImageTable.image_category).in_(category_strings))

            if is_intermediate is not None:
                stmt = stmt.where(col(ImageTable.is_intermediate) == is_intermediate)

            results = session.exec(stmt).all()
            return list(results)

    def get_board_for_image(self, image_name: str) -> Optional[str]:
        """Return the board id the image belongs to, or None."""
        with self._db.get_readonly_session() as session:
            row = session.get(BoardImageTable, image_name)
            return None if row is None else row.board_id

    def get_image_count_for_board(self, board_id: str) -> int:
        """Count the board's non-intermediate images in the image categories."""
        return self._count_for_board(board_id, IMAGE_CATEGORIES)

    def get_asset_count_for_board(self, board_id: str) -> int:
        """Count the board's non-intermediate images in the asset categories."""
        return self._count_for_board(board_id, ASSETS_CATEGORIES)

    def _count_for_board(self, board_id: str, categories: list[ImageCategory]) -> int:
        """Shared counting query behind the image/asset count methods."""
        category_strings = [c.value for c in set(categories)]
        with self._db.get_readonly_session() as session:
            stmt = (
                select(func.count())
                .select_from(BoardImageTable)
                .join(ImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name))
                .where(
                    col(ImageTable.is_intermediate) == False,  # noqa: E712
                    col(ImageTable.image_category).in_(category_strings),
                    col(BoardImageTable.board_id) == board_id,
                )
            )
            return session.exec(stmt).one()
|
||||
|
||||
|
||||
def _image_to_dict(row: ImageTable) -> dict:
    """Flatten an ImageTable row into the plain dict shape that
    deserialize_image_record expects.

    Note the column named ``metadata_`` on the model maps to the ``metadata``
    key of the dict.
    """
    return dict(
        image_name=row.image_name,
        image_origin=row.image_origin,
        image_category=row.image_category,
        width=row.width,
        height=row.height,
        session_id=row.session_id,
        node_id=row.node_id,
        metadata=row.metadata_,
        is_intermediate=row.is_intermediate,
        created_at=row.created_at,
        updated_at=row.updated_at,
        deleted_at=row.deleted_at,
        starred=row.starred,
        has_workflow=row.has_workflow,
    )
|
||||
177
invokeai/app/services/board_records/board_records_sqlmodel.py
Normal file
177
invokeai/app/services/board_records/board_records_sqlmodel.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
||||
from invokeai.app.services.board_records.board_records_common import (
|
||||
BoardChanges,
|
||||
BoardRecord,
|
||||
BoardRecordDeleteException,
|
||||
BoardRecordNotFoundException,
|
||||
BoardRecordOrderBy,
|
||||
BoardRecordSaveException,
|
||||
BoardVisibility,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.models import BoardTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
def _to_record(row: BoardTable) -> BoardRecord:
    """Convert a SQLModel BoardTable row to a BoardRecord pydantic model.

    Must be called while the row is still bound to an active Session.
    An unrecognized visibility string degrades to ``BoardVisibility.Private``.
    """
    raw_visibility = row.board_visibility
    try:
        visibility = BoardVisibility(raw_visibility)
    except ValueError:
        # Unknown/legacy value stored in the column: default to private.
        visibility = BoardVisibility.Private

    return BoardRecord(
        board_id=row.board_id,
        board_name=row.board_name,
        user_id=row.user_id,
        cover_image_name=row.cover_image_name,
        created_at=row.created_at,
        updated_at=row.updated_at,
        deleted_at=row.deleted_at,
        archived=row.archived,
        board_visibility=visibility,
    )
|
||||
|
||||
|
||||
class SqlModelBoardRecordStorage(BoardRecordStorageBase):
    """Board record storage using SQLModel."""

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def _list_conditions(self, user_id: str, is_admin: bool, include_archived: bool) -> list:
        """Build the shared WHERE conditions for board listing queries."""
        conditions = []
        if not is_admin:
            # Non-admins see their own boards plus shared/public boards.
            conditions.append(
                (col(BoardTable.user_id) == user_id) | (col(BoardTable.board_visibility).in_(["shared", "public"]))
            )
        if not include_archived:
            conditions.append(col(BoardTable.archived) == False)  # noqa: E712
        return conditions

    def delete(self, board_id: str) -> None:
        """Delete the board if it exists; storage errors are wrapped."""
        with self._db.get_session() as session:
            try:
                row = session.get(BoardTable, board_id)
                if row:
                    session.delete(row)
            except Exception as e:
                raise BoardRecordDeleteException from e

    def save(self, board_name: str, user_id: str) -> BoardRecord:
        """Create a new board owned by *user_id* and return its record."""
        new_board = BoardTable(board_id=uuid_string(), board_name=board_name, user_id=user_id)
        with self._db.get_session() as session:
            try:
                session.add(new_board)
                session.flush()  # populate DB-generated fields before conversion
                return _to_record(new_board)
            except Exception as e:
                raise BoardRecordSaveException from e

    def get(self, board_id: str) -> BoardRecord:
        """Fetch one board; raises BoardRecordNotFoundException when missing."""
        with self._db.get_readonly_session() as session:
            row = session.get(BoardTable, board_id)
            if row is None:
                raise BoardRecordNotFoundException
            return _to_record(row)

    def update(self, board_id: str, changes: BoardChanges) -> BoardRecord:
        """Apply the non-None fields of *changes* and return the updated record."""
        with self._db.get_session() as session:
            try:
                row = session.get(BoardTable, board_id)
                if row is None:
                    raise BoardRecordNotFoundException

                # Only fields present in the change set are touched.
                if changes.board_name is not None:
                    row.board_name = changes.board_name
                if changes.cover_image_name is not None:
                    row.cover_image_name = changes.cover_image_name
                if changes.archived is not None:
                    row.archived = changes.archived
                if changes.board_visibility is not None:
                    row.board_visibility = changes.board_visibility.value

                session.add(row)
                session.flush()
                return _to_record(row)
            except BoardRecordNotFoundException:
                raise
            except Exception as e:
                raise BoardRecordSaveException from e

    def get_many(
        self,
        user_id: str,
        is_admin: bool,
        order_by: BoardRecordOrderBy,
        direction: SQLiteDirection,
        offset: int = 0,
        limit: int = 10,
        include_archived: bool = False,
    ) -> OffsetPaginatedResults[BoardRecord]:
        """Paginated board listing, visibility-filtered for non-admin users."""
        with self._db.get_readonly_session() as session:
            conditions = self._list_conditions(user_id, is_admin, include_archived)

            total = session.exec(
                select(func.count()).select_from(BoardTable).where(*conditions)
            ).one()

            order_col = (
                col(BoardTable.created_at) if order_by == BoardRecordOrderBy.CreatedAt else col(BoardTable.board_name)
            )
            ordering = order_col.desc() if direction == SQLiteDirection.Descending else order_col.asc()

            rows = session.exec(
                select(BoardTable).where(*conditions).order_by(ordering).offset(offset).limit(limit)
            ).all()
            boards = [_to_record(r) for r in rows]

            return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=total)

    def get_all(
        self,
        user_id: str,
        is_admin: bool,
        order_by: BoardRecordOrderBy,
        direction: SQLiteDirection,
        include_archived: bool = False,
    ) -> list[BoardRecord]:
        """Unpaginated board listing, visibility-filtered for non-admin users."""
        with self._db.get_readonly_session() as session:
            conditions = self._list_conditions(user_id, is_admin, include_archived)

            order_col = (
                col(BoardTable.board_name) if order_by == BoardRecordOrderBy.Name else col(BoardTable.created_at)
            )
            ordering = order_col.desc() if direction == SQLiteDirection.Descending else order_col.asc()

            rows = session.exec(select(BoardTable).where(*conditions).order_by(ordering)).all()
            return [_to_record(r) for r in rows]
||||
@@ -0,0 +1,41 @@
|
||||
from sqlmodel import col, select
|
||||
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_base import ClientStatePersistenceABC
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.models import ClientStateTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class ClientStatePersistenceSqlModel(ClientStatePersistenceABC):
    """Persists per-user client (UI) state as key/value strings via SQLModel."""

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def start(self, invoker: Invoker) -> None:
        # Keep a handle on the invoker; no other startup work is needed.
        self._invoker = invoker

    def set_by_key(self, user_id: str, key: str, value: str) -> str:
        """Upsert *value* under (user_id, key) and echo the value back."""
        with self._db.get_session() as session:
            row = session.get(ClientStateTable, (user_id, key))
            if row is None:
                row = ClientStateTable(user_id=user_id, key=key, value=value)
            else:
                row.value = value
            session.add(row)
        return value

    def get_by_key(self, user_id: str, key: str) -> str | None:
        """Return the stored value for (user_id, key), or None when unset."""
        with self._db.get_readonly_session() as session:
            row = session.get(ClientStateTable, (user_id, key))
            return row.value if row is not None else None

    def delete(self, user_id: str) -> None:
        """Remove every stored key for *user_id*."""
        with self._db.get_session() as session:
            rows = session.exec(
                select(ClientStateTable).where(col(ClientStateTable.user_id) == user_id)
            ).all()
            for row in rows:
                session.delete(row)
||||
367
invokeai/app/services/image_records/image_records_sqlmodel.py
Normal file
367
invokeai/app/services/image_records/image_records_sqlmodel.py
Normal file
@@ -0,0 +1,367 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator
|
||||
from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
ImageRecordNotFoundException,
|
||||
ImageRecordSaveException,
|
||||
ResourceOrigin,
|
||||
deserialize_image_record,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.models import BoardImageTable, ImageTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
def _to_dict(row: ImageTable) -> dict:
    """Flatten an ImageTable row into the plain dict shape that
    deserialize_image_record expects.

    Note the column named ``metadata_`` on the model maps to the ``metadata``
    key of the dict.
    """
    return dict(
        image_name=row.image_name,
        image_origin=row.image_origin,
        image_category=row.image_category,
        width=row.width,
        height=row.height,
        session_id=row.session_id,
        node_id=row.node_id,
        metadata=row.metadata_,
        is_intermediate=row.is_intermediate,
        created_at=row.created_at,
        updated_at=row.updated_at,
        deleted_at=row.deleted_at,
        starred=row.starred,
        has_workflow=row.has_workflow,
    )
|
||||
|
||||
|
||||
class SqlModelImageRecordStorage(ImageRecordStorageBase):
    """Image record storage backed by SQLModel (images and board-image tables)."""

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db
|
||||
|
||||
def get(self, image_name: str) -> ImageRecord:
    """Fetch one image record by name; raises ImageRecordNotFoundException when missing."""
    with self._db.get_readonly_session() as session:
        if (row := session.get(ImageTable, image_name)) is None:
            raise ImageRecordNotFoundException
        return deserialize_image_record(_to_dict(row))
|
||||
|
||||
def get_user_id(self, image_name: str) -> Optional[str]:
    """Return the owning user's id for an image, or None if the image is unknown."""
    with self._db.get_readonly_session() as session:
        row = session.get(ImageTable, image_name)
        return None if row is None else row.user_id
|
||||
|
||||
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
    """Return the image's parsed metadata, or None when none was stored.

    Raises ImageRecordNotFoundException when the image itself is missing.
    """
    with self._db.get_readonly_session() as session:
        if (row := session.get(ImageTable, image_name)) is None:
            raise ImageRecordNotFoundException
        raw = row.metadata_
        return None if raw is None else MetadataFieldValidator.validate_json(raw)
|
||||
|
||||
def update(self, image_name: str, changes: ImageRecordChanges) -> None:
    """Apply the non-None fields of *changes* to an existing image record.

    Raises ImageRecordNotFoundException when the image is missing and
    ImageRecordSaveException for any other storage failure.
    """
    with self._db.get_session() as session:
        try:
            row = session.get(ImageTable, image_name)
            if row is None:
                raise ImageRecordNotFoundException

            # Only fields present in the change set are touched.
            if changes.image_category is not None:
                row.image_category = changes.image_category.value
            if changes.session_id is not None:
                row.session_id = changes.session_id
            if changes.is_intermediate is not None:
                row.is_intermediate = changes.is_intermediate
            if changes.starred is not None:
                row.starred = changes.starred

            session.add(row)
        except ImageRecordNotFoundException:
            raise
        except Exception as e:
            raise ImageRecordSaveException from e
|
||||
|
||||
def get_many(
    self,
    offset: int = 0,
    limit: int = 10,
    starred_first: bool = True,
    order_dir: SQLiteDirection = SQLiteDirection.Descending,
    image_origin: Optional[ResourceOrigin] = None,
    categories: Optional[list[ImageCategory]] = None,
    is_intermediate: Optional[bool] = None,
    board_id: Optional[str] = None,
    search_term: Optional[str] = None,
    user_id: Optional[str] = None,
    is_admin: bool = False,
) -> OffsetPaginatedResults[ImageRecord]:
    """Paginated image listing with optional filters and starred-first ordering."""
    with self._db.get_readonly_session() as session:
        # Left-join board_images so board filters can see unboarded images too.
        join_on = col(BoardImageTable.image_name) == col(ImageTable.image_name)
        stmt = select(ImageTable).outerjoin(BoardImageTable, join_on)
        count_stmt = select(func.count()).select_from(ImageTable).outerjoin(BoardImageTable, join_on)

        # Shared filter application (defined elsewhere on this class).
        stmt, count_stmt = self._apply_image_filters(
            stmt,
            count_stmt,
            image_origin,
            categories,
            is_intermediate,
            board_id,
            search_term,
            user_id,
            is_admin,
        )

        total = session.exec(count_stmt).one()

        created_order = (
            col(ImageTable.created_at).desc()
            if order_dir == SQLiteDirection.Descending
            else col(ImageTable.created_at).asc()
        )
        if starred_first:
            # Starred images sort ahead of the rest regardless of direction.
            stmt = stmt.order_by(col(ImageTable.starred).desc(), created_order)
        else:
            stmt = stmt.order_by(created_order)

        rows = session.exec(stmt.limit(limit).offset(offset)).all()
        images = [deserialize_image_record(_to_dict(r)) for r in rows]

        return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=total)
|
||||
|
||||
def delete(self, image_name: str) -> None:
    """Delete one image record; a missing image is a silent no-op."""
    with self._db.get_session() as session:
        try:
            if (row := session.get(ImageTable, image_name)) is not None:
                session.delete(row)
        except Exception as e:
            raise ImageRecordDeleteException from e
|
||||
|
||||
def delete_many(self, image_names: list[str]) -> None:
    """Delete every named image record within a single session."""
    with self._db.get_session() as session:
        try:
            rows = session.exec(
                select(ImageTable).where(col(ImageTable.image_name).in_(image_names))
            ).all()
            for row in rows:
                session.delete(row)
        except Exception as e:
            raise ImageRecordDeleteException from e
|
||||
|
||||
def get_intermediates_count(self, user_id: Optional[str] = None) -> int:
    """Count intermediate images, optionally restricted to one user."""
    with self._db.get_readonly_session() as session:
        stmt = (
            select(func.count())
            .select_from(ImageTable)
            .where(col(ImageTable.is_intermediate) == True)  # noqa: E712
        )
        if user_id is not None:
            stmt = stmt.where(col(ImageTable.user_id) == user_id)
        return session.exec(stmt).one()
|
||||
|
||||
def delete_intermediates(self) -> list[str]:
    """Delete all intermediate image records and return their names.

    The returned names let callers clean up the corresponding image files.
    """
    with self._db.get_session() as session:
        try:
            rows = session.exec(
                select(ImageTable).where(col(ImageTable.is_intermediate) == True)  # noqa: E712
            ).all()
            # Capture names before the rows are deleted.
            names = [row.image_name for row in rows]
            for row in rows:
                session.delete(row)
        except Exception as e:
            raise ImageRecordDeleteException from e
    return names
|
||||
|
||||
def save(
    self,
    image_name: str,
    image_origin: ResourceOrigin,
    image_category: ImageCategory,
    width: int,
    height: int,
    has_workflow: bool,
    is_intermediate: Optional[bool] = False,
    starred: Optional[bool] = False,
    session_id: Optional[str] = None,
    node_id: Optional[str] = None,
    metadata: Optional[str] = None,
    user_id: Optional[str] = None,
) -> datetime:
    """Insert a new image record and return its creation timestamp.

    Raises ImageRecordSaveException on any storage failure.
    """
    record = ImageTable(
        image_name=image_name,
        image_origin=image_origin.value,
        image_category=image_category.value,
        width=width,
        height=height,
        session_id=session_id,
        node_id=node_id,
        metadata_=metadata,
        is_intermediate=bool(is_intermediate),  # None coerces to False
        starred=bool(starred),
        has_workflow=has_workflow,
        user_id=user_id or "system",  # unowned images are attributed to "system"
    )
    with self._db.get_session() as session:
        try:
            session.add(record)
            session.flush()  # assigns DB-side defaults such as created_at
            created = record.created_at
            # With expire_on_commit=False the attribute stays readable after
            # commit; it may come back as a string depending on column type.
            return created if isinstance(created, datetime) else datetime.fromisoformat(str(created))
        except Exception as e:
            raise ImageRecordSaveException from e
|
||||
|
||||
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
    """Return the board's top image: starred images first, then newest.

    Intermediate images are excluded. Returns None when the board has no
    qualifying images.
    """
    with self._db.get_readonly_session() as session:
        stmt = (
            select(ImageTable)
            .join(BoardImageTable, col(ImageTable.image_name) == col(BoardImageTable.image_name))
            .where(
                col(BoardImageTable.board_id) == board_id,
                col(ImageTable.is_intermediate) == False,  # noqa: E712
            )
            # starred DESC makes any starred image win over newer unstarred ones.
            .order_by(col(ImageTable.starred).desc(), col(ImageTable.created_at).desc())
            .limit(1)
        )
        row = session.exec(stmt).first()
        if row is None:
            return None
        return deserialize_image_record(_to_dict(row))
|
||||
|
||||
def get_image_names(
    self,
    starred_first: bool = True,
    order_dir: SQLiteDirection = SQLiteDirection.Descending,
    image_origin: Optional[ResourceOrigin] = None,
    categories: Optional[list[ImageCategory]] = None,
    is_intermediate: Optional[bool] = None,
    board_id: Optional[str] = None,
    search_term: Optional[str] = None,
    user_id: Optional[str] = None,
    is_admin: bool = False,
) -> ImageNamesResult:
    """Return the filtered, ordered image names plus starred and total counts."""
    with self._db.get_readonly_session() as session:
        # Name query plus a matching count query. The count query is only
        # consumed for the starred tally below, but both must carry the exact
        # same filters, so they are built and filtered together.
        stmt = select(ImageTable.image_name).outerjoin(
            BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name)
        )
        count_stmt = (
            select(func.count())
            .select_from(ImageTable)
            .outerjoin(BoardImageTable, col(BoardImageTable.image_name) == col(ImageTable.image_name))
        )

        stmt, count_stmt = self._apply_image_filters(
            stmt,
            count_stmt,
            image_origin,
            categories,
            is_intermediate,
            board_id,
            search_term,
            user_id,
            is_admin,
        )

        starred_count = 0
        if starred_first:
            starred_count = session.exec(
                count_stmt.where(col(ImageTable.starred) == True)  # noqa: E712
            ).one()

        # Creation-time ordering is common to both branches; starred-first
        # merely prepends the starred flag as the primary sort key.
        created_order = (
            col(ImageTable.created_at).desc()
            if order_dir == SQLiteDirection.Descending
            else col(ImageTable.created_at).asc()
        )
        if starred_first:
            stmt = stmt.order_by(col(ImageTable.starred).desc(), created_order)
        else:
            stmt = stmt.order_by(created_order)

        image_names = list(session.exec(stmt).all())

        return ImageNamesResult(
            image_names=image_names,
            starred_count=starred_count,
            total_count=len(image_names),
        )
|
||||
|
||||
@staticmethod
def _apply_image_filters(
    stmt, count_stmt, image_origin, categories, is_intermediate, board_id, search_term, user_id, is_admin
):
    """Apply common filters to both data and count queries.

    Every filter is applied identically to `stmt` (the data query) and
    `count_stmt` (the count query) so their result sets stay in sync.
    Returns the pair `(stmt, count_stmt)`.
    """
    if image_origin is not None:
        cond = col(ImageTable.image_origin) == image_origin.value
        stmt = stmt.where(cond)
        count_stmt = count_stmt.where(cond)

    if categories is not None:
        # De-duplicate before building the IN clause.
        category_strings = [c.value for c in set(categories)]
        cond = col(ImageTable.image_category).in_(category_strings)
        stmt = stmt.where(cond)
        count_stmt = count_stmt.where(cond)

    if is_intermediate is not None:
        cond = col(ImageTable.is_intermediate) == is_intermediate
        stmt = stmt.where(cond)
        count_stmt = count_stmt.where(cond)

    if board_id == "none":
        # "none" is a sentinel meaning images with no board assignment
        # (NULL board_id via the outer join).
        cond = col(BoardImageTable.board_id).is_(None)
        stmt = stmt.where(cond)
        count_stmt = count_stmt.where(cond)
        # NOTE(review): the per-user restriction is only applied in this
        # "no board" branch; queries filtered by a specific board — or with
        # no board filter at all — are not scoped to user_id. Confirm
        # whether that is intentional.
        if user_id is not None and not is_admin:
            user_cond = col(ImageTable.user_id) == user_id
            stmt = stmt.where(user_cond)
            count_stmt = count_stmt.where(user_cond)
    elif board_id is not None:
        cond = col(BoardImageTable.board_id) == board_id
        stmt = stmt.where(cond)
        count_stmt = count_stmt.where(cond)

    if search_term:
        # Case-insensitive substring match against metadata or created_at text.
        term = f"%{search_term.lower()}%"
        cond = col(ImageTable.metadata_).like(term) | col(ImageTable.created_at).like(term)
        stmt = stmt.where(cond)
        count_stmt = count_stmt.where(cond)

    return stmt, count_stmt
|
||||
235
invokeai/app/services/model_records/model_records_sqlmodel.py
Normal file
235
invokeai/app/services/model_records/model_records_sqlmodel.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""SQLModel implementation of ModelRecordServiceBase."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import pydantic
|
||||
from sqlalchemy import func, literal_column
|
||||
from sqlmodel import select
|
||||
|
||||
from invokeai.app.services.model_records.model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
ModelRecordServiceBase,
|
||||
ModelSummary,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.models import ModelTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.configs.factory import AnyModelConfig, ModelConfigFactory
|
||||
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType
|
||||
|
||||
# Mapping from ModelRecordOrderBy to column expressions.
# Values are raw SQL column lists (Default is a comma-separated multi-column
# sort) intended for use with `literal_column` against the models table.
_ORDER_COLS = {
    ModelRecordOrderBy.Default: "type, base, name, format",
    ModelRecordOrderBy.Type: "type",
    ModelRecordOrderBy.Base: "base",
    ModelRecordOrderBy.Name: "name",
    ModelRecordOrderBy.Format: "format",
}
|
||||
|
||||
|
||||
class ModelRecordServiceSqlModel(ModelRecordServiceBase):
    """SQLModel implementation of ModelConfigStore.

    Configs are persisted as JSON blobs in `ModelTable.config`. Attribute
    filters and ordering go through `literal_column("name")` etc., which
    presumes the models table exposes those columns (e.g. as generated
    columns mirroring the raw-SQL implementation) — TODO confirm against the
    table definition.
    """

    def __init__(self, db: SqliteDatabase, logger: logging.Logger):
        """Keep references to the database handle and logger used by all operations."""
        super().__init__()
        self._db = db
        self._logger = logger

    @staticmethod
    def _apply_order(stmt, order_by: ModelRecordOrderBy):
        """Attach ORDER BY clauses for `order_by`, driven by `_ORDER_COLS`.

        Previously the same if/elif ladder was duplicated in `search_by_attr`
        and `list_models` while the module-level `_ORDER_COLS` mapping went
        unused; this helper is now the single source of truth.
        """
        columns = [literal_column(name.strip()) for name in _ORDER_COLS[order_by].split(",")]
        return stmt.order_by(*columns)

    def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
        """Insert a new model config row and return the stored config.

        Raises:
            DuplicateModelException: when a unique constraint (path,
                name/type/base, or key) is violated.
        """
        row = ModelTable(id=config.key, config=config.model_dump_json())
        try:
            with self._db.get_session() as session:
                session.add(row)
        except Exception as e:
            err_str = str(e)
            if "UNIQUE constraint failed" in err_str:
                # Translate the violated constraint into a human-readable message.
                if "models.path" in err_str:
                    msg = f"A model with path '{config.path}' is already installed"
                elif "models.name" in err_str:
                    msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
                else:
                    msg = f"A model with key '{config.key}' is already installed"
                raise DuplicateModelException(msg) from e
            raise
        return self.get_model(config.key)

    def del_model(self, key: str) -> None:
        """Delete the model row with `key`.

        Raises:
            UnknownModelException: if no such model exists.
        """
        with self._db.get_session() as session:
            row = session.get(ModelTable, key)
            if row is None:
                raise UnknownModelException("model not found")
            session.delete(row)

    def update_model(self, key: str, changes: ModelRecordChanges, allow_class_change: bool = False) -> AnyModelConfig:
        """Apply `changes` to the stored config and return the updated record.

        With `allow_class_change`, the merged dict is re-validated through the
        factory so the config may change concrete class; otherwise fields are
        assigned in place on the existing config object.
        """
        record = self.get_model(key)

        if allow_class_change:
            record_as_dict = record.model_dump()
            for field_name in changes.model_fields_set:
                record_as_dict[field_name] = getattr(changes, field_name)
            record = ModelConfigFactory.from_dict(record_as_dict)
        else:
            for field_name in changes.model_fields_set:
                setattr(record, field_name, getattr(changes, field_name))

        json_serialized = record.model_dump_json()

        with self._db.get_session() as session:
            row = session.get(ModelTable, key)
            if row is None:
                raise UnknownModelException("model not found")
            row.config = json_serialized
            session.add(row)

        return self.get_model(key)

    def replace_model(self, key: str, new_config: AnyModelConfig) -> AnyModelConfig:
        """Overwrite the stored config wholesale with `new_config`.

        Raises:
            ValueError: if `key` does not match `new_config.key`.
            UnknownModelException: if no such model exists.
        """
        if key != new_config.key:
            raise ValueError("key does not match new_config.key")
        with self._db.get_session() as session:
            row = session.get(ModelTable, key)
            if row is None:
                raise UnknownModelException("model not found")
            row.config = new_config.model_dump_json()
            session.add(row)
        return self.get_model(key)

    def get_model(self, key: str) -> AnyModelConfig:
        """Fetch and validate the config for `key`.

        Raises:
            UnknownModelException: if no such model exists.
        """
        with self._db.get_readonly_session() as session:
            row = session.get(ModelTable, key)
            if row is None:
                raise UnknownModelException("model not found")
            return ModelConfigFactory.from_dict(json.loads(row.config))

    def get_model_by_hash(self, hash: str) -> AnyModelConfig:
        """Fetch the first config whose hash column matches.

        Raises:
            UnknownModelException: if no row matches.
        """
        with self._db.get_readonly_session() as session:
            stmt = select(ModelTable).where(literal_column("hash") == hash)
            row = session.exec(stmt).first()
            if row is None:
                raise UnknownModelException("model not found")
            return ModelConfigFactory.from_dict(json.loads(row.config))

    def exists(self, key: str) -> bool:
        """Return True if a model row with `key` exists."""
        with self._db.get_readonly_session() as session:
            row = session.get(ModelTable, key)
            return row is not None

    def search_by_attr(
        self,
        model_name: Optional[str] = None,
        base_model: Optional[BaseModelType] = None,
        model_type: Optional[ModelType] = None,
        model_format: Optional[ModelFormat] = None,
        order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default,
    ) -> List[AnyModelConfig]:
        """Return configs matching all given attribute filters, ordered per `order_by`.

        Rows whose stored JSON no longer validates are skipped with a warning
        instead of failing the whole listing.
        """
        with self._db.get_readonly_session() as session:
            stmt = select(ModelTable)

            if model_name:
                stmt = stmt.where(literal_column("name") == model_name)
            if base_model:
                stmt = stmt.where(literal_column("base") == base_model)
            if model_type:
                stmt = stmt.where(literal_column("type") == model_type)
            if model_format:
                stmt = stmt.where(literal_column("format") == model_format)

            stmt = self._apply_order(stmt, order_by)

            rows = session.exec(stmt).all()
            # Extract config strings while still in the session
            config_strings = [row.config for row in rows]

        results: list[AnyModelConfig] = []
        for config_str in config_strings:
            try:
                model_config = ModelConfigFactory.from_dict(json.loads(config_str))
            except pydantic.ValidationError as e:
                config_preview = f"{config_str[:64]}..." if len(config_str) > 64 else config_str
                try:
                    name = json.loads(config_str).get("name", "<unknown>")
                except Exception:
                    name = "<unknown>"
                self._logger.warning(
                    f"Skipping invalid model config in the database with name {name}. ({config_preview})"
                )
                self._logger.warning(f"Validation error: {e}")
            else:
                results.append(model_config)

        return results

    def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
        """Return all configs whose path column equals `path`."""
        with self._db.get_readonly_session() as session:
            stmt = select(ModelTable).where(literal_column("path") == str(path))
            rows = session.exec(stmt).all()
            configs = [r.config for r in rows]
        return [ModelConfigFactory.from_dict(json.loads(c)) for c in configs]

    def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
        """Return all configs whose hash column equals `hash`."""
        with self._db.get_readonly_session() as session:
            stmt = select(ModelTable).where(literal_column("hash") == hash)
            rows = session.exec(stmt).all()
            configs = [r.config for r in rows]
        return [ModelConfigFactory.from_dict(json.loads(c)) for c in configs]

    def list_models(
        self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
    ) -> PaginatedResults[ModelSummary]:
        """Return one page of model summaries plus pagination metadata."""
        with self._db.get_readonly_session() as session:
            # Total count
            count_stmt = select(func.count()).select_from(ModelTable)
            total = session.exec(count_stmt).one()

            # Data query, ordered and paged
            stmt = self._apply_order(select(ModelTable), order_by)
            stmt = stmt.limit(per_page).offset(page * per_page)
            rows = session.exec(stmt).all()
            configs = [r.config for r in rows]

        items = [ModelSummary.model_validate({"config": c}) for c in configs]
        return PaginatedResults(
            page=page,
            pages=ceil(total / per_page),
            per_page=per_page,
            total=total,
            items=items,
        )
|
||||
@@ -0,0 +1,54 @@
|
||||
from sqlmodel import col, select
|
||||
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
|
||||
ModelRelationshipRecordStorageBase,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.models import ModelRelationshipTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
class SqlModelModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
    """SQLModel-backed storage for symmetric model-to-model relationships.

    Each pair is stored once with the two keys sorted, so the relationship is
    order-independent: (a, b) and (b, a) map to the same row.
    """

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
        """Create the relationship if it does not already exist (idempotent).

        Raises:
            ValueError: if both keys name the same model.
        """
        if model_key_1 == model_key_2:
            raise ValueError("Cannot relate a model to itself.")
        # Canonicalize order so (a, b) and (b, a) hit the same primary key.
        a, b = sorted([model_key_1, model_key_2])
        with self._db.get_session() as session:
            existing = session.get(ModelRelationshipTable, (a, b))
            if existing is None:
                session.add(ModelRelationshipTable(model_key_1=a, model_key_2=b))

    def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
        """Delete the relationship if present; a no-op when it does not exist."""
        a, b = sorted([model_key_1, model_key_2])
        with self._db.get_session() as session:
            existing = session.get(ModelRelationshipTable, (a, b))
            if existing is not None:
                session.delete(existing)

    def get_related_model_keys(self, model_key: str) -> list[str]:
        """Return all keys related to `model_key` (order not guaranteed)."""
        # A single-key lookup is just the batch lookup with one key; this
        # removes the previous copy-pasted query pair.
        return self.get_related_model_keys_batch([model_key])

    def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
        """Return the de-duplicated union of keys related to any key in `model_keys`."""
        with self._db.get_readonly_session() as session:
            # A key may appear in either column, so query both directions.
            stmt1 = select(ModelRelationshipTable.model_key_2).where(
                col(ModelRelationshipTable.model_key_1).in_(model_keys)
            )
            stmt2 = select(ModelRelationshipTable.model_key_1).where(
                col(ModelRelationshipTable.model_key_2).in_(model_keys)
            )
            results1 = session.exec(stmt1).all()
            results2 = session.exec(stmt2).all()
            return list(set(results1 + results2))
|
||||
843
invokeai/app/services/session_queue/session_queue_sqlmodel.py
Normal file
843
invokeai/app/services/session_queue/session_queue_sqlmodel.py
Normal file
@@ -0,0 +1,843 @@
|
||||
"""SQLModel-backed implementation of the session queue service.
|
||||
|
||||
This module is the Phase 3 sibling of `session_queue_sqlite.py`. It uses
|
||||
SQLAlchemy Core for the hot paths (bulk enqueue/cancel/delete, dequeue, list
|
||||
with cursor pagination, aggregations) and keeps the same external behaviour as
|
||||
the raw-SQL implementation, including reliance on the existing DB triggers for
|
||||
`started_at`, `completed_at` and `updated_at`.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic_core import to_jsonable_python
|
||||
from sqlalchemy import and_, delete, func, insert, or_, select, update
|
||||
from sqlalchemy.engine import Row
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
DEFAULT_QUEUE_ID,
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelAllExceptCurrentResult,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
ItemIdsResult,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
ValueToInsertTuple,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.models import SessionQueueTable, UserTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
# Statuses from which a queue item never transitions again (see
# `_set_queue_item_status`, which refuses to update items in these states).
_TERMINAL_STATUSES: tuple[str, ...] = ("completed", "failed", "canceled")

# Columns selected for queue-item reads. `workflow` is not listed here; it is
# added explicitly where needed (see `_select_queue_item_with_user`).
_QUEUE_COLUMNS = (
    SessionQueueTable.item_id,
    SessionQueueTable.batch_id,
    SessionQueueTable.queue_id,
    SessionQueueTable.session_id,
    SessionQueueTable.field_values,
    SessionQueueTable.session,
    SessionQueueTable.status,
    SessionQueueTable.priority,
    SessionQueueTable.error_traceback,
    SessionQueueTable.created_at,
    SessionQueueTable.updated_at,
    SessionQueueTable.started_at,
    SessionQueueTable.completed_at,
    SessionQueueTable.error_type,
    SessionQueueTable.error_message,
    SessionQueueTable.origin,
    SessionQueueTable.destination,
    SessionQueueTable.retried_from_item_id,
    SessionQueueTable.user_id,
)
|
||||
|
||||
|
||||
def _row_to_queue_item_dict(row: Row) -> dict[str, Any]:
|
||||
"""Convert a Row produced by `_select_queue_item_with_user` to a plain dict
|
||||
that `SessionQueueItem.queue_item_from_dict` expects."""
|
||||
mapping = dict(row._mapping)
|
||||
# Stringify datetime columns so the Pydantic union (`datetime | str`) accepts them
|
||||
# consistently across queries that JOIN datetime columns from multiple tables.
|
||||
for ts_key in ("created_at", "updated_at", "started_at", "completed_at"):
|
||||
ts_value = mapping.get(ts_key)
|
||||
if ts_value is not None and not isinstance(ts_value, str):
|
||||
mapping[ts_key] = str(ts_value)
|
||||
mapping.setdefault("user_display_name", None)
|
||||
mapping.setdefault("user_email", None)
|
||||
mapping.setdefault("workflow", None)
|
||||
return mapping
|
||||
|
||||
|
||||
def _select_queue_item_with_user():
    """Build a SELECT that mirrors `sq.*, u.display_name, u.email` with LEFT JOIN."""
    return (
        select(
            *_QUEUE_COLUMNS,
            SessionQueueTable.workflow,
            # Labels match the keys consumed by `_row_to_queue_item_dict`.
            UserTable.display_name.label("user_display_name"),
            UserTable.email.label("user_email"),
        )
        .select_from(SessionQueueTable)
        # isouter=True: queue items with no matching user row still come
        # back, with NULL user columns.
        .join(UserTable, SessionQueueTable.user_id == UserTable.user_id, isouter=True)
    )
|
||||
|
||||
|
||||
def _value_tuple_to_dict(t: ValueToInsertTuple) -> dict[str, Any]:
|
||||
"""Adapt the positional tuple from `prepare_values_to_insert` to a dict that
|
||||
SQLAlchemy Core's `insert(...).values([...])` expects."""
|
||||
return {
|
||||
"queue_id": t[0],
|
||||
"session": t[1],
|
||||
"session_id": t[2],
|
||||
"batch_id": t[3],
|
||||
"field_values": t[4],
|
||||
"priority": t[5],
|
||||
"workflow": t[6],
|
||||
"origin": t[7],
|
||||
"destination": t[8],
|
||||
"retried_from_item_id": t[9],
|
||||
"user_id": t[10],
|
||||
}
|
||||
|
||||
|
||||
class SqlModelSessionQueue(SessionQueueBase):
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
    """Store the database handle; the invoker is attached later in `start()`."""
    super().__init__()
    self._db = db
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
    """Service startup hook: recover orphaned items, then clear or prune per config."""
    self.__invoker = invoker
    # Anything left in_progress from a previous run can never finish; cancel it.
    self._set_in_progress_to_canceled()
    config = self.__invoker.services.configuration
    if config.clear_queue_on_startup:
        clear_result = self.clear(DEFAULT_QUEUE_ID)
        if clear_result.deleted > 0:
            self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
        # Clearing and history-pruning are mutually exclusive: a cleared
        # queue has no history left to prune.
        return

    if config.max_queue_history is not None:
        deleted = self._prune_terminal_to_limit(DEFAULT_QUEUE_ID, config.max_queue_history)
        if deleted > 0:
            self.__invoker.services.logger.info(
                f"Pruned {deleted} completed/failed/canceled queue items "
                f"(kept up to {config.max_queue_history})"
            )
|
||||
|
||||
# region: internal helpers
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
    """Sets all in_progress queue items to canceled. Run on app startup."""
    cancel_stmt = (
        update(SessionQueueTable)
        .where(SessionQueueTable.status == "in_progress")
        .values(status="canceled")
    )
    with self._db.get_session() as session:
        session.execute(cancel_stmt)
|
||||
|
||||
def _prune_terminal_to_limit(self, queue_id: str, keep: int) -> int:
    """Prune terminal items (completed/failed/canceled) to keep at most N most-recent items.

    Returns the number of rows deleted.
    """
    terminal_filter = and_(
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.status.in_(_TERMINAL_STATUSES),
    )
    # Subquery: ids of the items we want to keep (most recent N)
    keep_ids_stmt = (
        select(SessionQueueTable.item_id)
        .where(terminal_filter)
        .order_by(
            # "Recency" prefers completed_at, falling back to updated_at and
            # then created_at for rows whose later timestamps are NULL.
            func.coalesce(
                SessionQueueTable.completed_at,
                SessionQueueTable.updated_at,
                SessionQueueTable.created_at,
            ).desc(),
            # item_id breaks timestamp ties deterministically.
            SessionQueueTable.item_id.desc(),
        )
        .limit(keep)
    )
    with self._db.get_session() as session:
        # Count first so the caller can report how many rows the DELETE removes.
        count_stmt = (
            select(func.count())
            .select_from(SessionQueueTable)
            .where(terminal_filter)
            .where(~SessionQueueTable.item_id.in_(keep_ids_stmt))
        )
        count = session.execute(count_stmt).scalar_one()
        session.execute(
            delete(SessionQueueTable)
            .where(terminal_filter)
            .where(~SessionQueueTable.item_id.in_(keep_ids_stmt))
        )
        return int(count)
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
    """Gets the current number of pending queue items."""
    pending_count_stmt = (
        select(func.count())
        .select_from(SessionQueueTable)
        .where(
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.status == "pending",
        )
    )
    with self._db.get_readonly_session() as session:
        return int(session.execute(pending_count_stmt).scalar_one())
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
    """Gets the highest priority value in the queue."""
    max_priority_stmt = select(func.max(SessionQueueTable.priority)).where(
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.status == "pending",
    )
    with self._db.get_readonly_session() as session:
        highest = session.execute(max_priority_stmt).scalar()
    # MAX() over an empty set yields NULL; treat that as priority 0.
    return 0 if highest is None else int(highest)
|
||||
|
||||
# endregion
|
||||
|
||||
# region: enqueue / dequeue / read single
|
||||
|
||||
async def enqueue_batch(
    self, queue_id: str, batch: Batch, prepend: bool, user_id: str = "system"
) -> EnqueueBatchResult:
    """Expand `batch` into queue items and insert them, capped by max_queue_size.

    Batch expansion (`calc_session_count` / `prepare_values_to_insert`) runs
    in a worker thread to keep the event loop responsive. Emits a
    batch-enqueued event and returns requested/enqueued counts plus item ids.
    """
    current_queue_size = self._get_current_queue_size(queue_id)
    max_queue_size = self.__invoker.services.configuration.max_queue_size
    max_new_queue_items = max_queue_size - current_queue_size

    priority = 0
    if prepend:
        # Prepended batches outrank everything currently pending.
        priority = self._get_highest_priority(queue_id) + 1

    requested_count = await asyncio.to_thread(calc_session_count, batch=batch)
    values_to_insert = await asyncio.to_thread(
        prepare_values_to_insert,
        queue_id=queue_id,
        batch=batch,
        priority=priority,
        max_new_queue_items=max_new_queue_items,
        user_id=user_id,
    )
    # May be fewer than requested when the max_new_queue_items cap applies.
    enqueued_count = len(values_to_insert)

    with self._db.get_session() as session:
        if values_to_insert:
            # Bulk "executemany"-style insert via SQLAlchemy Core.
            session.execute(
                insert(SessionQueueTable),
                [_value_tuple_to_dict(v) for v in values_to_insert],
            )
        item_ids_rows = session.execute(
            select(SessionQueueTable.item_id)
            .where(SessionQueueTable.batch_id == batch.batch_id)
            .order_by(SessionQueueTable.item_id.desc())
        ).all()
        item_ids = [row[0] for row in item_ids_rows]

    enqueue_result = EnqueueBatchResult(
        queue_id=queue_id,
        requested=requested_count,
        enqueued=enqueued_count,
        batch=batch,
        priority=priority,
        item_ids=item_ids,
    )
    self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id)
    return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
    """Pop the next pending item (highest priority, then lowest item_id) and mark it in_progress.

    Returns None when nothing is pending.
    """
    with self._db.get_readonly_session() as session:
        row = session.execute(
            _select_queue_item_with_user()
            .where(SessionQueueTable.status == "pending")
            .order_by(SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc())
            .limit(1)
        ).first()
        if row is None:
            return None
        queue_item = SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
        # Status flip happens in a separate writable session and emits the
        # status-changed event; return the refreshed item from that call.
        return self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
    """Peek at the next pending item for `queue_id` without changing its status.

    NOTE(review): ties are broken by created_at here but by item_id in
    `dequeue` — confirm the two orderings are meant to differ.
    """
    with self._db.get_readonly_session() as session:
        row = session.execute(
            _select_queue_item_with_user()
            .where(
                SessionQueueTable.queue_id == queue_id,
                SessionQueueTable.status == "pending",
            )
            .order_by(SessionQueueTable.priority.desc(), SessionQueueTable.created_at.asc())
            .limit(1)
        ).first()
        if row is None:
            return None
        return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
    """Return the in-progress item for `queue_id`, or None if the queue is idle."""
    in_progress_stmt = (
        _select_queue_item_with_user()
        .where(
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.status == "in_progress",
        )
        .limit(1)
    )
    with self._db.get_readonly_session() as session:
        row = session.execute(in_progress_stmt).first()
        if row is None:
            return None
        return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
    """Fetch a single queue item by id.

    Raises:
        SessionQueueItemNotFoundError: when no item has `item_id`.
    """
    lookup_stmt = _select_queue_item_with_user().where(SessionQueueTable.item_id == item_id)
    with self._db.get_readonly_session() as session:
        row = session.execute(lookup_stmt).first()
        if row is None:
            raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
        return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
|
||||
|
||||
# endregion
|
||||
|
||||
# region: status mutation
|
||||
|
||||
def _set_queue_item_status(
    self,
    item_id: int,
    status: QUEUE_ITEM_STATUS,
    error_type: Optional[str] = None,
    error_message: Optional[str] = None,
    error_traceback: Optional[str] = None,
) -> SessionQueueItem:
    """Transition a queue item to `status` and return the refreshed item.

    Items already in a terminal status (completed/failed/canceled) are never
    transitioned again; in that case the current item is returned unchanged
    and no status-changed event is emitted. Per the module docstring, the
    started_at/completed_at/updated_at timestamps are maintained by DB
    triggers, not set here.

    Raises:
        SessionQueueItemNotFoundError: when no item has `item_id`.
    """
    with self._db.get_session() as session:
        current_status = session.execute(
            select(SessionQueueTable.status).where(SessionQueueTable.item_id == item_id)
        ).scalar()
        if current_status is None:
            raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")

        # Only update if not already finished (completed, failed or canceled)
        if current_status in _TERMINAL_STATUSES:
            # No update; fall through to fetch + return below.
            pass
        else:
            session.execute(
                update(SessionQueueTable)
                .where(SessionQueueTable.item_id == item_id)
                .values(
                    status=status,
                    error_type=error_type,
                    error_message=error_message,
                    error_traceback=error_traceback,
                )
            )

    # Re-read after commit so trigger-managed columns are current.
    queue_item = self.get_queue_item(item_id)

    # If we did not update, do not emit a status change event.
    if current_status not in _TERMINAL_STATUSES:
        batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
        queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
        self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
    return queue_item
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
    """Mark the item canceled (no-op if it is already in a terminal status)."""
    return self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
    """Mark the item completed (no-op if it is already in a terminal status)."""
    return self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
|
||||
def fail_queue_item(
    self,
    item_id: int,
    error_type: str,
    error_message: str,
    error_traceback: str,
) -> SessionQueueItem:
    """Mark the item failed, recording the error details on the row."""
    return self._set_queue_item_status(
        item_id=item_id,
        status="failed",
        error_type=error_type,
        error_message=error_message,
        error_traceback=error_traceback,
    )
|
||||
|
||||
def delete_queue_item(self, item_id: int) -> None:
    """Cancel (best-effort) and then delete a queue item."""
    try:
        # Cancel first so listeners receive a status-changed event before
        # the row disappears.
        self.cancel_queue_item(item_id)
    except SessionQueueItemNotFoundError:
        # Already gone; the DELETE below is then a no-op.
        pass
    with self._db.get_session() as session:
        session.execute(delete(SessionQueueTable).where(SessionQueueTable.item_id == item_id))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session_state: GraphExecutionState) -> SessionQueueItem:
    """Persist the serialized graph-execution state onto the queue item and return it."""
    # Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause
    # validation errors when the graph is loaded. Graph execution occurs purely in memory - the
    # session saved here is not referenced during execution.
    session_json = session_state.model_dump_json(warnings=False, exclude_none=True)
    with self._db.get_session() as session:
        session.execute(
            update(SessionQueueTable)
            .where(SessionQueueTable.item_id == item_id)
            .values(session=session_json)
        )
    return self.get_queue_item(item_id)
|
||||
|
||||
# endregion
|
||||
|
||||
# region: simple status checks
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
with self._db.get_readonly_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count())
|
||||
.select_from(SessionQueueTable)
|
||||
.where(SessionQueueTable.queue_id == queue_id)
|
||||
).scalar_one()
|
||||
return IsEmptyResult(is_empty=int(count) == 0)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
with self._db.get_readonly_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count())
|
||||
.select_from(SessionQueueTable)
|
||||
.where(SessionQueueTable.queue_id == queue_id)
|
||||
).scalar_one()
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
return IsFullResult(is_full=int(count) >= max_queue_size)
|
||||
|
||||
# endregion
|
||||
|
||||
# region: bulk delete
|
||||
|
||||
    def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
        """Delete every item in the queue (optionally only one user's items) and emit a cleared event.

        Returns the number of rows deleted.
        """
        where = [SessionQueueTable.queue_id == queue_id]
        if user_id is not None:
            where.append(SessionQueueTable.user_id == user_id)

        with self._db.get_session() as session:
            # Count first so we can report how many rows the DELETE removes;
            # both statements run in the same session/transaction.
            count = session.execute(
                select(func.count()).select_from(SessionQueueTable).where(*where)
            ).scalar_one()
            session.execute(delete(SessionQueueTable).where(*where))
        self.__invoker.services.events.emit_queue_cleared(queue_id)
        return ClearResult(deleted=int(count))
|
||||
|
||||
    def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult:
        """Delete all items in a terminal status (completed/failed/canceled), optionally per-user.

        Returns the number of rows deleted.
        """
        where = [
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.status.in_(_TERMINAL_STATUSES),
        ]
        if user_id is not None:
            where.append(SessionQueueTable.user_id == user_id)

        with self._db.get_session() as session:
            # Count then delete in one session so the reported count matches the DELETE.
            count = session.execute(
                select(func.count()).select_from(SessionQueueTable).where(*where)
            ).scalar_one()
            session.execute(delete(SessionQueueTable).where(*where))
        return PruneResult(deleted=int(count))
|
||||
|
||||
    def delete_by_destination(
        self, queue_id: str, destination: str, user_id: Optional[str] = None
    ) -> DeleteByDestinationResult:
        """Delete all items with the given destination, canceling the current item first if it matches.

        When `user_id` is provided, only that user's items are affected.
        Returns the number of rows deleted.
        """
        # Handle current in-progress item BEFORE opening a write session of our own,
        # to avoid nested writes on the single StaticPool connection.
        current_queue_item = self.get_current(queue_id)
        if current_queue_item is not None and current_queue_item.destination == destination:
            if user_id is None or current_queue_item.user_id == user_id:
                self.cancel_queue_item(current_queue_item.item_id)

        where = [
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.destination == destination,
        ]
        if user_id is not None:
            where.append(SessionQueueTable.user_id == user_id)

        with self._db.get_session() as session:
            count = session.execute(
                select(func.count()).select_from(SessionQueueTable).where(*where)
            ).scalar_one()
            session.execute(delete(SessionQueueTable).where(*where))
        return DeleteByDestinationResult(deleted=int(count))
|
||||
|
||||
    def delete_all_except_current(
        self, queue_id: str, user_id: Optional[str] = None
    ) -> DeleteAllExceptCurrentResult:
        """Delete all *pending* items in the queue, leaving the current (in-progress) item alone.

        Only `pending` rows are targeted, which implicitly excludes the current item.
        Returns the number of rows deleted.
        """
        where = [
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.status == "pending",
        ]
        if user_id is not None:
            where.append(SessionQueueTable.user_id == user_id)

        with self._db.get_session() as session:
            count = session.execute(
                select(func.count()).select_from(SessionQueueTable).where(*where)
            ).scalar_one()
            session.execute(delete(SessionQueueTable).where(*where))
        return DeleteAllExceptCurrentResult(deleted=int(count))
|
||||
|
||||
# endregion
|
||||
|
||||
# region: bulk cancel
|
||||
|
||||
    def _cancel_skip_in_progress_filter(
        self, queue_id: str, user_id: Optional[str], extra: list
    ) -> list:
        """Build a WHERE-clause list selecting items eligible for bulk cancellation.

        Skips items already in a terminal state and items that are in_progress —
        callers cancel the current item separately via `_set_queue_item_status`.
        Optionally scopes to a user and appends caller-supplied conditions.
        """
        where = [
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.status.notin_(("canceled", "completed", "failed", "in_progress")),
        ]
        if user_id is not None:
            where.append(SessionQueueTable.user_id == user_id)
        where.extend(extra)
        return where
|
||||
|
||||
    def cancel_by_batch_ids(
        self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None
    ) -> CancelByBatchIDsResult:
        """Cancel all cancelable items belonging to any of the given batches.

        The current in-progress item (excluded from the bulk UPDATE) is canceled
        individually so its status-change event fires.
        """
        current_queue_item = self.get_current(queue_id)
        where = self._cancel_skip_in_progress_filter(
            queue_id, user_id, [SessionQueueTable.batch_id.in_(batch_ids)]
        )
        with self._db.get_session() as session:
            count = session.execute(
                select(func.count()).select_from(SessionQueueTable).where(*where)
            ).scalar_one()
            session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))

        # Handle current item separately - check ownership if user_id is provided
        if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
            if user_id is None or current_queue_item.user_id == user_id:
                self._set_queue_item_status(current_queue_item.item_id, "canceled")

        return CancelByBatchIDsResult(canceled=int(count))
|
||||
|
||||
    def cancel_by_destination(
        self, queue_id: str, destination: str, user_id: Optional[str] = None
    ) -> CancelByDestinationResult:
        """Cancel all cancelable items with the given destination.

        The current in-progress item (excluded from the bulk UPDATE) is canceled
        individually so its status-change event fires.
        """
        current_queue_item = self.get_current(queue_id)
        where = self._cancel_skip_in_progress_filter(
            queue_id, user_id, [SessionQueueTable.destination == destination]
        )
        with self._db.get_session() as session:
            count = session.execute(
                select(func.count()).select_from(SessionQueueTable).where(*where)
            ).scalar_one()
            session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))

        # Cancel the current item separately, honoring ownership when user_id is given.
        if current_queue_item is not None and current_queue_item.destination == destination:
            if user_id is None or current_queue_item.user_id == user_id:
                self._set_queue_item_status(current_queue_item.item_id, "canceled")

        return CancelByDestinationResult(canceled=int(count))
|
||||
|
||||
    def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
        """Cancel every cancelable item in the queue, plus the current in-progress item.

        NOTE(review): unlike the other bulk-cancel methods this takes no user_id
        filter — presumably an admin/system-level operation; confirm intentional.
        """
        current_queue_item = self.get_current(queue_id)
        where = [
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.status.notin_(("canceled", "completed", "failed", "in_progress")),
        ]
        with self._db.get_session() as session:
            count = session.execute(
                select(func.count()).select_from(SessionQueueTable).where(*where)
            ).scalar_one()
            session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))

        # Cancel the current item via _set_queue_item_status so its events fire.
        if current_queue_item is not None and current_queue_item.queue_id == queue_id:
            self._set_queue_item_status(current_queue_item.item_id, "canceled")
        return CancelByQueueIDResult(canceled=int(count))
|
||||
|
||||
    def cancel_all_except_current(
        self, queue_id: str, user_id: Optional[str] = None
    ) -> CancelAllExceptCurrentResult:
        """Cancel all *pending* items, leaving the current (in-progress) item running.

        Only `pending` rows are targeted, which implicitly excludes the current item.
        Returns the number of rows canceled.
        """
        where = [
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.status == "pending",
        ]
        if user_id is not None:
            where.append(SessionQueueTable.user_id == user_id)

        with self._db.get_session() as session:
            count = session.execute(
                select(func.count()).select_from(SessionQueueTable).where(*where)
            ).scalar_one()
            session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
        return CancelAllExceptCurrentResult(canceled=int(count))
|
||||
|
||||
# endregion
|
||||
|
||||
# region: list / pagination
|
||||
|
||||
    def list_queue_items(
        self,
        queue_id: str,
        limit: int,
        priority: int,
        cursor: Optional[int] = None,
        status: Optional[QUEUE_ITEM_STATUS] = None,
        destination: Optional[str] = None,
    ) -> CursorPaginatedResults[SessionQueueItem]:
        """Return a cursor-paginated page of queue items ordered by priority desc, item_id asc.

        The cursor is the item_id of the last item on the previous page; items
        "after" it are those with lower priority, or equal priority and higher id.
        """
        # NOTE: this preserves the cursor semantics of the original raw-SQL
        # implementation (which relied on unparenthesised `AND ... OR ...`
        # precedence); the equivalent grouping is spelled out explicitly here
        # with `or_`/`and_`.
        item_id = cursor

        stmt = select(*_QUEUE_COLUMNS, SessionQueueTable.workflow).where(
            SessionQueueTable.queue_id == queue_id
        )
        if status is not None:
            stmt = stmt.where(SessionQueueTable.status == status)
        if destination is not None:
            stmt = stmt.where(SessionQueueTable.destination == destination)
        if item_id is not None:
            stmt = stmt.where(
                or_(
                    SessionQueueTable.priority < priority,
                    and_(
                        SessionQueueTable.priority == priority,
                        SessionQueueTable.item_id > item_id,
                    ),
                )
            )
        # Fetch one extra row as a sentinel to detect whether more pages exist.
        stmt = stmt.order_by(
            SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc()
        ).limit(limit + 1)

        with self._db.get_readonly_session() as session:
            rows = session.execute(stmt).all()

        items = [SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(r)) for r in rows]
        has_more = False
        if len(items) > limit:
            # Drop the sentinel row — it belongs to the next page.
            items.pop()
            has_more = True
        return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def list_all_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
stmt = _select_queue_item_with_user().where(SessionQueueTable.queue_id == queue_id)
|
||||
if destination is not None:
|
||||
stmt = stmt.where(SessionQueueTable.destination == destination)
|
||||
stmt = stmt.order_by(
|
||||
SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc()
|
||||
)
|
||||
with self._db.get_readonly_session() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
return [SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(r)) for r in rows]
|
||||
|
||||
def get_queue_item_ids(
|
||||
self,
|
||||
queue_id: str,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
user_id: Optional[str] = None,
|
||||
) -> ItemIdsResult:
|
||||
stmt = select(SessionQueueTable.item_id).where(SessionQueueTable.queue_id == queue_id)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(SessionQueueTable.user_id == user_id)
|
||||
if order_dir == SQLiteDirection.Descending:
|
||||
stmt = stmt.order_by(SessionQueueTable.created_at.desc())
|
||||
else:
|
||||
stmt = stmt.order_by(SessionQueueTable.created_at.asc())
|
||||
|
||||
with self._db.get_readonly_session() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
item_ids = [row[0] for row in rows]
|
||||
return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids))
|
||||
|
||||
# endregion
|
||||
|
||||
# region: aggregations
|
||||
|
||||
    def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
        """Return per-status item counts for the queue, plus the current item's identifiers.

        When `user_id` is given, counts are restricted to that user and the
        current item's details are hidden unless that user owns it.
        """
        stmt = (
            select(SessionQueueTable.status, func.count())
            .where(SessionQueueTable.queue_id == queue_id)
            .group_by(SessionQueueTable.status)
        )
        if user_id is not None:
            stmt = stmt.where(SessionQueueTable.user_id == user_id)

        with self._db.get_readonly_session() as session:
            rows = session.execute(stmt).all()

        current_item = self.get_current(queue_id=queue_id)
        total = sum(int(row[1] or 0) for row in rows)
        # Map each status value to its count; statuses with no rows are absent.
        counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}

        # For non-admin users, hide current item details if they don't own it
        show_current_item = current_item is not None and (
            user_id is None or current_item.user_id == user_id
        )

        return SessionQueueStatus(
            queue_id=queue_id,
            item_id=current_item.item_id if show_current_item else None,
            session_id=current_item.session_id if show_current_item else None,
            batch_id=current_item.batch_id if show_current_item else None,
            pending=counts.get("pending", 0),
            in_progress=counts.get("in_progress", 0),
            completed=counts.get("completed", 0),
            failed=counts.get("failed", 0),
            canceled=counts.get("canceled", 0),
            total=total,
        )
|
||||
|
||||
    def get_batch_status(
        self, queue_id: str, batch_id: str, user_id: Optional[str] = None
    ) -> BatchStatus:
        """Return per-status item counts for a single batch, with its origin/destination."""
        stmt = (
            select(
                SessionQueueTable.status,
                func.count(),
                SessionQueueTable.origin,
                SessionQueueTable.destination,
            )
            .where(
                SessionQueueTable.queue_id == queue_id,
                SessionQueueTable.batch_id == batch_id,
            )
            .group_by(SessionQueueTable.status)
        )
        if user_id is not None:
            stmt = stmt.where(SessionQueueTable.user_id == user_id)

        with self._db.get_readonly_session() as session:
            rows = session.execute(stmt).all()

        total = sum(int(row[1] or 0) for row in rows)
        counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}
        # origin/destination are read from the first grouped row — assumes they
        # are constant across all items in a batch (TODO confirm invariant).
        origin = rows[0][2] if rows else None
        destination = rows[0][3] if rows else None

        return BatchStatus(
            batch_id=batch_id,
            origin=origin,
            destination=destination,
            queue_id=queue_id,
            pending=counts.get("pending", 0),
            in_progress=counts.get("in_progress", 0),
            completed=counts.get("completed", 0),
            failed=counts.get("failed", 0),
            canceled=counts.get("canceled", 0),
            total=total,
        )
|
||||
|
||||
    def get_counts_by_destination(
        self, queue_id: str, destination: str, user_id: Optional[str] = None
    ) -> SessionQueueCountsByDestination:
        """Return per-status item counts for all items with the given destination."""
        stmt = (
            select(SessionQueueTable.status, func.count())
            .where(
                SessionQueueTable.queue_id == queue_id,
                SessionQueueTable.destination == destination,
            )
            .group_by(SessionQueueTable.status)
        )
        if user_id is not None:
            stmt = stmt.where(SessionQueueTable.user_id == user_id)

        with self._db.get_readonly_session() as session:
            rows = session.execute(stmt).all()

        total = sum(int(row[1] or 0) for row in rows)
        # Map each status value to its count; statuses with no rows are absent.
        counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}

        return SessionQueueCountsByDestination(
            queue_id=queue_id,
            destination=destination,
            pending=counts.get("pending", 0),
            in_progress=counts.get("in_progress", 0),
            completed=counts.get("completed", 0),
            failed=counts.get("failed", 0),
            canceled=counts.get("canceled", 0),
            total=total,
        )
|
||||
|
||||
# endregion
|
||||
|
||||
# region: retry
|
||||
|
||||
    def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
        """Re-enqueue failed/canceled items as fresh pending copies.

        Each eligible item is cloned with a new GraphExecutionState (same graph,
        new session id) and linked back to its original via retried_from_item_id.
        Items not in a failed/canceled state are skipped silently. Emits a
        queue_items_retried event and returns the ids of the retried originals.
        """
        values_to_insert: list[ValueToInsertTuple] = []
        retried_item_ids: list[int] = []

        for item_id in item_ids:
            queue_item = self.get_queue_item(item_id)
            # Only terminal, non-successful items are retryable.
            if queue_item.status not in ("failed", "canceled"):
                continue
            retried_item_ids.append(item_id)

            field_values_json = (
                json.dumps(queue_item.field_values, default=to_jsonable_python)
                if queue_item.field_values
                else None
            )
            workflow_json = (
                json.dumps(queue_item.workflow, default=to_jsonable_python)
                if queue_item.workflow
                else None
            )
            # Clone only the graph — execution state (results, errors) is not carried over.
            cloned_session = GraphExecutionState(graph=queue_item.session.graph)
            cloned_session_json = cloned_session.model_dump_json(warnings=False, exclude_none=True)

            # Chain retries back to the *original* item, not the most recent retry.
            retried_from_item_id = (
                queue_item.retried_from_item_id
                if queue_item.retried_from_item_id is not None
                else queue_item.item_id
            )

            # NOTE: tuple field order must match ValueToInsertTuple / _value_tuple_to_dict.
            values_to_insert.append(
                (
                    queue_item.queue_id,
                    cloned_session_json,
                    cloned_session.id,
                    queue_item.batch_id,
                    field_values_json,
                    queue_item.priority,
                    workflow_json,
                    queue_item.origin,
                    queue_item.destination,
                    retried_from_item_id,
                    queue_item.user_id,
                )
            )

        # TODO(psyche): Handle max queue size?
        if values_to_insert:
            with self._db.get_session() as session:
                session.execute(
                    insert(SessionQueueTable),
                    [_value_tuple_to_dict(v) for v in values_to_insert],
                )

        retry_result = RetryItemsResult(queue_id=queue_id, retried_item_ids=retried_item_ids)
        self.__invoker.services.events.emit_queue_items_retried(retry_result)
        return retry_result
|
||||
|
||||
# endregion
|
||||
252
invokeai/app/services/shared/sqlite/models.py
Normal file
252
invokeai/app/services/shared/sqlite/models.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""SQLModel table definitions for the InvokeAI database.
|
||||
|
||||
These models mirror the schema created by the raw SQL migrations.
|
||||
The migrations remain the source of truth for schema changes —
|
||||
these models are used only for querying via SQLModel/SQLAlchemy.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Column, String
|
||||
from sqlalchemy.schema import FetchedValue
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
# --- boards ---
|
||||
|
||||
|
||||
class BoardTable(SQLModel, table=True):
    """Mirrors the `boards` table.

    Schema source of truth is the raw SQL migrations; this model exists only
    for ORM querying.
    """

    __tablename__ = "boards"

    board_id: str = Field(primary_key=True)
    board_name: str
    # Name of the image used as the board's cover, if any.
    cover_image_name: Optional[str] = Field(default=None)
    # NOTE(review): datetime.utcnow is deprecated in Python 3.12+ and yields
    # naive timestamps — kept to match existing rows; confirm before changing.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    deleted_at: Optional[datetime] = Field(default=None)
    archived: bool = Field(default=False)
    user_id: str = Field(default="system")
    is_public: bool = Field(default=False)
    board_visibility: str = Field(default="private")
|
||||
|
||||
|
||||
class BoardImageTable(SQLModel, table=True):
    """Mirrors the `board_images` junction table (image -> board membership)."""

    __tablename__ = "board_images"

    # An image belongs to at most one board, so image_name alone is the PK.
    image_name: str = Field(primary_key=True)
    board_id: str = Field(foreign_key="boards.board_id")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    deleted_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
|
||||
class SharedBoardTable(SQLModel, table=True):
    """Mirrors the `shared_boards` table (per-user board sharing grants)."""

    __tablename__ = "shared_boards"

    # Composite PK: one row per (board, user) share.
    board_id: str = Field(primary_key=True, foreign_key="boards.board_id")
    user_id: str = Field(primary_key=True, foreign_key="users.user_id")
    can_edit: bool = Field(default=False)
    shared_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# --- images ---
|
||||
|
||||
|
||||
class ImageTable(SQLModel, table=True):
    """Mirrors the `images` table."""

    __tablename__ = "images"

    image_name: str = Field(primary_key=True)
    image_origin: str
    image_category: str
    width: int
    height: int
    session_id: Optional[str] = Field(default=None)
    node_id: Optional[str] = Field(default=None)
    # DB column is named "metadata"; the attribute carries a trailing underscore
    # because `metadata` is reserved on SQLModel/SQLAlchemy declarative classes.
    metadata_: Optional[str] = Field(default=None, sa_column_kwargs={"name": "metadata"})
    is_intermediate: bool = Field(default=False)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    deleted_at: Optional[datetime] = Field(default=None)
    starred: bool = Field(default=False)
    has_workflow: bool = Field(default=False)
    user_id: str = Field(default="system")
|
||||
|
||||
|
||||
# --- workflows ---
|
||||
|
||||
|
||||
class WorkflowLibraryTable(SQLModel, table=True):
    """Mirrors the `workflow_library` table."""

    __tablename__ = "workflow_library"

    workflow_id: str = Field(primary_key=True)
    workflow: str  # JSON blob
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    opened_at: Optional[datetime] = Field(default=None)
    # Generated columns — server-side, excluded from INSERT/UPDATE.
    # FetchedValue marks them as DB-computed so SQLAlchemy never writes them.
    category: Optional[str] = Field(default=None, sa_column=Column(String, FetchedValue(), server_default=None))
    name: Optional[str] = Field(default=None, sa_column=Column(String, FetchedValue(), server_default=None))
    description: Optional[str] = Field(default=None, sa_column=Column(String, FetchedValue(), server_default=None))
    tags: Optional[str] = Field(default=None, sa_column=Column(String, FetchedValue(), server_default=None))
    user_id: str = Field(default="system")
    is_public: bool = Field(default=False)
|
||||
|
||||
|
||||
class WorkflowImageTable(SQLModel, table=True):
    """Mirrors the `workflow_images` junction table (image -> workflow link)."""

    __tablename__ = "workflow_images"

    # An image links to at most one workflow, so image_name alone is the PK.
    image_name: str = Field(primary_key=True, foreign_key="images.image_name")
    workflow_id: str = Field(foreign_key="workflow_library.workflow_id")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    deleted_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
|
||||
# --- session queue ---
|
||||
|
||||
|
||||
class SessionQueueTable(SQLModel, table=True):
    """Mirrors the `session_queue` table."""

    __tablename__ = "session_queue"

    # AUTOINCREMENT primary key — None until the row is inserted.
    item_id: Optional[int] = Field(default=None, primary_key=True)
    batch_id: str
    queue_id: str
    session_id: str = Field(unique=True)
    # JSON blob of user-supplied field overrides, if any.
    field_values: Optional[str] = Field(default=None)
    session: str  # JSON blob
    status: str = Field(default="pending")
    priority: int = Field(default=0)
    error_traceback: Optional[str] = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    started_at: Optional[datetime] = Field(default=None)
    completed_at: Optional[datetime] = Field(default=None)
    error_type: Optional[str] = Field(default=None)
    error_message: Optional[str] = Field(default=None)
    origin: Optional[str] = Field(default=None)
    destination: Optional[str] = Field(default=None)
    # Points at the original item when this row was created by a retry.
    retried_from_item_id: Optional[int] = Field(default=None)
    user_id: str = Field(default="system")
    workflow: Optional[str] = Field(default=None)  # JSON blob
|
||||
|
||||
|
||||
# --- models ---
|
||||
|
||||
|
||||
class ModelTable(SQLModel, table=True):
    """Mirrors the `models` table.

    Most columns are GENERATED ALWAYS from the `config` JSON blob.
    We define them here for read access but they should not be set directly.
    """

    __tablename__ = "models"

    id: str = Field(primary_key=True)
    config: str  # JSON blob — all model metadata is extracted from this via GENERATED ALWAYS columns
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    # NOTE: The `models` table has many GENERATED ALWAYS columns (hash, base, type, path, format, name, etc.)
    # that are automatically extracted from the `config` JSON blob by SQLite.
    # We intentionally do NOT define them here because SQLAlchemy would try to include them in
    # INSERT/UPDATE statements, which fails on GENERATED columns.
    # To query by these columns, use raw text filters or the `text()` function.
    # The ModelRecordServiceSqlModel extracts all needed data from the `config` JSON blob directly.
|
||||
|
||||
|
||||
class ModelManagerMetadataTable(SQLModel, table=True):
    """Mirrors the `model_manager_metadata` table (simple key/value store)."""

    __tablename__ = "model_manager_metadata"

    metadata_key: str = Field(primary_key=True)
    metadata_value: str
|
||||
|
||||
|
||||
class ModelRelationshipTable(SQLModel, table=True):
    """Mirrors the `model_relationships` table (undirected model-to-model links)."""

    __tablename__ = "model_relationships"

    # Composite PK: one row per related model pair.
    model_key_1: str = Field(primary_key=True)
    model_key_2: str = Field(primary_key=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# --- style presets ---
|
||||
|
||||
|
||||
class StylePresetTable(SQLModel, table=True):
    """Mirrors the `style_presets` table."""

    __tablename__ = "style_presets"

    id: str = Field(primary_key=True)
    name: str
    preset_data: str  # JSON blob
    # Preset source: "user" for user-created; other values come from default/project presets.
    type: str = Field(default="user")
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    user_id: str = Field(default="system")
    is_public: bool = Field(default=False)
|
||||
|
||||
|
||||
# --- users & auth ---
|
||||
|
||||
|
||||
class UserTable(SQLModel, table=True):
    """Mirrors the `users` table."""

    __tablename__ = "users"

    user_id: str = Field(primary_key=True)
    email: str = Field(unique=True)
    display_name: Optional[str] = Field(default=None)
    # Stored as a hash only — never the plaintext password.
    password_hash: str
    is_admin: bool = Field(default=False)
    is_active: bool = Field(default=True)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    last_login_at: Optional[datetime] = Field(default=None)
|
||||
|
||||
|
||||
# --- app settings ---
|
||||
|
||||
|
||||
class AppSettingTable(SQLModel, table=True):
    """Mirrors the `app_settings` table (global key/value settings)."""

    __tablename__ = "app_settings"

    key: str = Field(primary_key=True)
    value: str
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# --- client state ---
|
||||
|
||||
|
||||
class ClientStateTable(SQLModel, table=True):
    """Mirrors the `client_state` table (per-user UI state key/value store)."""

    __tablename__ = "client_state"

    # Composite PK: one value per (user, key).
    user_id: str = Field(primary_key=True, foreign_key="users.user_id")
    key: str = Field(primary_key=True)
    value: str
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
@@ -4,6 +4,11 @@ from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from logging import Logger
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlmodel import Session, create_engine
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import sqlite_memory
|
||||
|
||||
@@ -25,6 +30,7 @@ class SqliteDatabase:
|
||||
- `conn`: A `sqlite3.Connection` object. Note that the connection must never be closed if the database is in-memory.
|
||||
- `lock`: A shared re-entrant lock, used to approximate thread safety.
|
||||
- `clean()`: Runs the SQL `VACUUM;` command and reports on the freed space.
|
||||
- `get_session()`: Returns a SQLModel Session for ORM-based queries.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
|
||||
@@ -55,6 +61,54 @@ class SqliteDatabase:
|
||||
# Set a busy timeout to prevent database lockups during writes
|
||||
self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
|
||||
|
||||
# Set up the SQLAlchemy engine for SQLModel-based queries.
|
||||
# For file-based DBs, both connections point to the same file.
|
||||
# For in-memory DBs, we use a named shared cache so both connections
|
||||
# see the same database.
|
||||
if self._db_path:
|
||||
db_uri = f"sqlite:///{self._db_path}"
|
||||
# StaticPool reuses a single connection — ideal for SQLite which
|
||||
# serializes writes anyway. Avoids the overhead of creating a new
|
||||
# connection for every Session.
|
||||
self._engine = create_engine(
|
||||
db_uri,
|
||||
echo=self._verbose,
|
||||
connect_args={"check_same_thread": False},
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
else:
|
||||
# Use a shared in-memory database via URI with shared cache.
|
||||
# The raw sqlite3 connection above already created ":memory:",
|
||||
# so we re-create it with the shared URI instead.
|
||||
shared_uri = f"file:invokeai_memdb_{uuid4().hex}?mode=memory&cache=shared"
|
||||
self._conn.close()
|
||||
self._conn = sqlite3.connect(shared_uri, uri=True, check_same_thread=False)
|
||||
self._conn.row_factory = sqlite3.Row
|
||||
if self._verbose:
|
||||
self._conn.set_trace_callback(self._logger.debug)
|
||||
self._conn.execute("PRAGMA foreign_keys = ON;")
|
||||
|
||||
self._engine = create_engine(
|
||||
"sqlite+pysqlite://",
|
||||
echo=self._verbose,
|
||||
creator=lambda: sqlite3.connect(shared_uri, uri=True, check_same_thread=False),
|
||||
poolclass=StaticPool,
|
||||
)
|
||||
|
||||
# Apply the same PRAGMAs to all SQLAlchemy connections
|
||||
@event.listens_for(self._engine, "connect")
|
||||
def _set_sqlite_pragmas(dbapi_connection, connection_record): # type: ignore
|
||||
cursor = dbapi_connection.cursor()
|
||||
# Note: We intentionally skip PRAGMA foreign_keys for the SQLAlchemy engine.
|
||||
# Migration 22 renames the `models` table which corrupts FK references in
|
||||
# `model_relationships`. The raw sqlite3 connection already enforces FKs
|
||||
# for the migration phase. The SQLAlchemy engine is used only for queries
|
||||
# after migrations are complete.
|
||||
if self._db_path:
|
||||
cursor.execute("PRAGMA journal_mode = WAL;")
|
||||
cursor.execute("PRAGMA busy_timeout = 5000;")
|
||||
cursor.close()
|
||||
|
||||
def clean(self) -> None:
|
||||
"""
|
||||
Cleans the database by running the VACUUM command, reporting on the freed space.
|
||||
@@ -75,6 +129,36 @@ class SqliteDatabase:
|
||||
self._logger.error(f"Error cleaning database: {e}")
|
||||
raise
|
||||
|
||||
    @contextmanager
    def get_session(self) -> Generator[Session, None, None]:
        """
        Context manager that yields a SQLModel Session for write operations.
        Commits on success, rolls back on exception.

        Uses expire_on_commit=False so that model attributes remain accessible
        after commit without triggering lazy-loads or DetachedInstanceError.
        """
        with Session(self._engine, expire_on_commit=False) as session:
            try:
                yield session
                # commit() is inside the try so a failed commit is also rolled back.
                session.commit()
            except Exception:
                session.rollback()
                raise
|
||||
|
||||
    @contextmanager
    def get_readonly_session(self) -> Generator[Session, None, None]:
        """
        Context manager that yields a lightweight read-only SQLModel Session.

        Optimized for SELECT queries:
        - autoflush=False: skips the automatic flush before every query
        - no commit/rollback: avoids transaction overhead for reads
        - expire_on_commit=False: attributes stay accessible after close

        NOTE: "read-only" is by convention only — the session is not enforced
        read-only at the engine level, so callers must not write through it.
        """
        with Session(self._engine, expire_on_commit=False, autoflush=False) as session:
            yield session
|
||||
|
||||
@contextmanager
|
||||
def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,115 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from sqlmodel import col, select
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.sqlite.models import StylePresetTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_base import StylePresetRecordsStorageBase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetType,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetRecordDTO,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
def _to_dto(row: StylePresetTable) -> StylePresetRecordDTO:
    """Convert a StylePresetTable ORM row into a StylePresetRecordDTO."""
    payload = {
        "id": row.id,
        "name": row.name,
        "preset_data": row.preset_data,
        "type": row.type,
        # Timestamps are stringified; the DTO validator parses them back.
        "created_at": str(row.created_at),
        "updated_at": str(row.updated_at),
    }
    return StylePresetRecordDTO.from_dict(payload)
|
||||
|
||||
|
||||
class SqlModelStylePresetRecordsStorage(StylePresetRecordsStorageBase):
    """SQLModel-backed storage for style preset records (CRUD plus default-preset sync)."""

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def start(self, invoker: Invoker) -> None:
        """Called at app startup: store the invoker and (re-)seed default presets."""
        self._invoker = invoker
        self._sync_default_style_presets()

    def get(self, style_preset_id: str) -> StylePresetRecordDTO:
        """Fetch a single preset by id.

        Raises:
            StylePresetNotFoundError: if no row exists with the given id.
        """
        with self._db.get_readonly_session() as session:
            row = session.get(StylePresetTable, style_preset_id)
            if row is None:
                raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
            return _to_dto(row)

    def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
        """Insert a new preset with a generated id and return the stored record."""
        style_preset_id = uuid_string()
        row = StylePresetTable(
            id=style_preset_id,
            name=style_preset.name,
            preset_data=style_preset.preset_data.model_dump_json(),
            type=style_preset.type,
        )
        with self._db.get_session() as session:
            session.add(row)
        # Re-read after commit so DB-populated fields (timestamps) are included.
        return self.get(style_preset_id)

    def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
        """Bulk-insert presets in a single transaction."""
        with self._db.get_session() as session:
            for style_preset in style_presets:
                row = StylePresetTable(
                    id=uuid_string(),
                    name=style_preset.name,
                    preset_data=style_preset.preset_data.model_dump_json(),
                    type=style_preset.type,
                )
                session.add(row)

    def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
        """Apply partial changes (name and/or preset_data) to an existing preset.

        Raises:
            StylePresetNotFoundError: if no row exists with the given id.
        """
        with self._db.get_session() as session:
            row = session.get(StylePresetTable, style_preset_id)
            if row is None:
                raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")

            # Only fields explicitly present in `changes` are written.
            if changes.name is not None:
                row.name = changes.name
            if changes.preset_data is not None:
                row.preset_data = changes.preset_data.model_dump_json()

            session.add(row)
        return self.get(style_preset_id)

    def delete(self, style_preset_id: str) -> None:
        """Delete a preset; a no-op if the id does not exist."""
        with self._db.get_session() as session:
            row = session.get(StylePresetTable, style_preset_id)
            if row is not None:
                session.delete(row)

    def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
        """List presets ordered by name ascending, optionally filtered by preset type."""
        with self._db.get_readonly_session() as session:
            stmt = select(StylePresetTable)
            if type is not None:
                stmt = stmt.where(col(StylePresetTable.type) == type)
            stmt = stmt.order_by(col(StylePresetTable.name).asc())
            rows = session.exec(stmt).all()
            return [_to_dto(r) for r in rows]

    def _sync_default_style_presets(self) -> None:
        """Syncs default style presets to the database."""
        # Delete existing defaults
        with self._db.get_session() as session:
            stmt = select(StylePresetTable).where(col(StylePresetTable.type) == "default")
            rows = session.exec(stmt).all()
            for row in rows:
                session.delete(row)

        # Re-create from file
        # NOTE(review): each create() opens its own session/transaction; a crash
        # mid-loop leaves the defaults partially re-created until next startup.
        with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
            presets = json.load(file)
            for preset in presets:
                style_preset = StylePresetWithoutId.model_validate(preset)
                self.create(style_preset)
|
||||
191
invokeai/app/services/users/users_sqlmodel.py
Normal file
191
invokeai/app/services/users/users_sqlmodel.py
Normal file
@@ -0,0 +1,191 @@
|
||||
"""SQLModel implementation of user service."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from invokeai.app.services.auth.password_utils import hash_password, validate_password_strength, verify_password
|
||||
from invokeai.app.services.shared.sqlite.models import UserTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_base import UserServiceBase
|
||||
from invokeai.app.services.users.users_common import UserCreateRequest, UserDTO, UserUpdateRequest
|
||||
|
||||
|
||||
def _to_dto(row: UserTable) -> UserDTO:
    """Map a UserTable ORM row onto a UserDTO, field for field."""
    field_names = (
        "user_id",
        "email",
        "display_name",
        "is_admin",
        "is_active",
        "created_at",
        "updated_at",
        "last_login_at",
    )
    return UserDTO(**{name: getattr(row, name) for name in field_names})
|
||||
|
||||
|
||||
class UserServiceSqlModel(UserServiceBase):
    """SQLModel-based user service."""

    def __init__(self, db: SqliteDatabase):
        self._db = db

    def create(self, user_data: UserCreateRequest, strict_password_checking: bool = True) -> UserDTO:
        """Create a user after validating password and email uniqueness.

        Raises:
            ValueError: on weak/empty password or duplicate email.
            RuntimeError: if the freshly created row cannot be read back.
        """
        if strict_password_checking:
            is_valid, error_msg = validate_password_strength(user_data.password)
            if not is_valid:
                raise ValueError(error_msg)
        elif not user_data.password:
            raise ValueError("Password cannot be empty")

        # NOTE(review): check-then-insert is not atomic; a concurrent create with
        # the same email could race past this guard — confirm a DB unique index exists.
        if self.get_by_email(user_data.email) is not None:
            raise ValueError(f"User with email {user_data.email} already exists")

        user_id = str(uuid4())
        password_hash = hash_password(user_data.password)

        user = UserTable(
            user_id=user_id,
            email=user_data.email,
            display_name=user_data.display_name,
            password_hash=password_hash,
            is_admin=user_data.is_admin,
        )
        with self._db.get_session() as session:
            session.add(user)

        # Re-read after commit so DB-populated fields (timestamps) are included.
        result = self.get(user_id)
        if result is None:
            raise RuntimeError("Failed to retrieve created user")
        return result

    def get(self, user_id: str) -> UserDTO | None:
        """Fetch a user by id, or None if not found."""
        with self._db.get_readonly_session() as session:
            row = session.get(UserTable, user_id)
            if row is None:
                return None
            return _to_dto(row)

    def get_by_email(self, email: str) -> UserDTO | None:
        """Fetch a user by exact email match, or None if not found."""
        with self._db.get_readonly_session() as session:
            stmt = select(UserTable).where(col(UserTable.email) == email)
            row = session.exec(stmt).first()
            if row is None:
                return None
            return _to_dto(row)

    def update(self, user_id: str, changes: UserUpdateRequest, strict_password_checking: bool = True) -> UserDTO:
        """Apply partial changes to an existing user; only non-None fields are written.

        Raises:
            ValueError: if the user does not exist or the new password is invalid.
            RuntimeError: if the updated row cannot be read back.
        """
        user = self.get(user_id)
        if user is None:
            raise ValueError(f"User {user_id} not found")

        # Validate the password (if provided) before opening the write session.
        if changes.password is not None:
            if strict_password_checking:
                is_valid, error_msg = validate_password_strength(changes.password)
                if not is_valid:
                    raise ValueError(error_msg)
            elif not changes.password:
                raise ValueError("Password cannot be empty")

        with self._db.get_session() as session:
            row = session.get(UserTable, user_id)
            if row is None:
                raise ValueError(f"User {user_id} not found")

            if changes.display_name is not None:
                row.display_name = changes.display_name
            if changes.password is not None:
                row.password_hash = hash_password(changes.password)
            if changes.is_admin is not None:
                row.is_admin = changes.is_admin
            if changes.is_active is not None:
                row.is_active = changes.is_active

            session.add(row)

        updated_user = self.get(user_id)
        if updated_user is None:
            raise RuntimeError("Failed to retrieve updated user")
        return updated_user

    def delete(self, user_id: str) -> None:
        """Delete a user.

        Raises:
            ValueError: if the user does not exist.
        """
        with self._db.get_session() as session:
            row = session.get(UserTable, user_id)
            if row is None:
                raise ValueError(f"User {user_id} not found")
            session.delete(row)

    def authenticate(self, email: str, password: str) -> UserDTO | None:
        """Verify credentials; on success stamp last_login_at and return the user.

        Returns None both for unknown email and for a wrong password, so callers
        cannot distinguish the two cases.
        """
        with self._db.get_session() as session:
            stmt = select(UserTable).where(col(UserTable.email) == email)
            row = session.exec(stmt).first()
            if row is None:
                return None

            if not verify_password(password, row.password_hash):
                return None

            # Record the successful login; committed by the session context manager.
            row.last_login_at = datetime.now(timezone.utc)
            session.add(row)

            return _to_dto(row)

    def has_admin(self) -> bool:
        """Return True if at least one active admin user exists."""
        with self._db.get_readonly_session() as session:
            stmt = (
                select(func.count())
                .select_from(UserTable)
                .where(
                    col(UserTable.is_admin) == True,  # noqa: E712
                    col(UserTable.is_active) == True,  # noqa: E712
                )
            )
            count = session.exec(stmt).one()
            return count > 0

    def create_admin(self, user_data: UserCreateRequest, strict_password_checking: bool = True) -> UserDTO:
        """Create the first admin user; rejects if an active admin already exists."""
        if self.has_admin():
            raise ValueError("Admin user already exists")

        # Force is_admin=True regardless of what the request carried.
        admin_data = UserCreateRequest(
            email=user_data.email,
            display_name=user_data.display_name,
            password=user_data.password,
            is_admin=True,
        )
        return self.create(admin_data, strict_password_checking=strict_password_checking)

    def list_users(self, limit: int = 100, offset: int = 0) -> list[UserDTO]:
        """List users newest-first with simple limit/offset pagination."""
        with self._db.get_readonly_session() as session:
            stmt = select(UserTable).order_by(col(UserTable.created_at).desc()).limit(limit).offset(offset)
            rows = session.exec(stmt).all()
            return [_to_dto(r) for r in rows]

    def get_admin_email(self) -> str | None:
        """Return the email of the oldest active admin, or None if there is none."""
        with self._db.get_readonly_session() as session:
            stmt = (
                select(UserTable)
                .where(
                    col(UserTable.is_admin) == True,  # noqa: E712
                    col(UserTable.is_active) == True,  # noqa: E712
                )
                .order_by(col(UserTable.created_at).asc())
                .limit(1)
            )
            row = session.exec(stmt).first()
            return row.email if row else None

    def count_admins(self) -> int:
        """Return the number of active admin users."""
        with self._db.get_readonly_session() as session:
            stmt = (
                select(func.count())
                .select_from(UserTable)
                .where(
                    col(UserTable.is_admin) == True,  # noqa: E712
                    col(UserTable.is_active) == True,  # noqa: E712
                )
            )
            count = session.exec(stmt).one()
            return count
|
||||
@@ -0,0 +1,431 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import col, select
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.shared.pagination import PaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.models import WorkflowLibraryTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.workflow_records.workflow_records_base import WorkflowRecordsStorageBase
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WORKFLOW_LIBRARY_DEFAULT_USER_ID,
|
||||
Workflow,
|
||||
WorkflowCategory,
|
||||
WorkflowNotFoundError,
|
||||
WorkflowRecordDTO,
|
||||
WorkflowRecordListItemDTO,
|
||||
WorkflowRecordListItemDTOValidator,
|
||||
WorkflowRecordOrderBy,
|
||||
WorkflowValidator,
|
||||
WorkflowWithoutID,
|
||||
)
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
||||
|
||||
def _row_to_dto(row: WorkflowLibraryTable) -> WorkflowRecordDTO:
    """Convert a WorkflowLibraryTable ORM row into a WorkflowRecordDTO."""
    # opened_at may be NULL for workflows that were never opened.
    opened_at = str(row.opened_at) if row.opened_at else None
    payload = {
        "workflow_id": row.workflow_id,
        "workflow": row.workflow,
        "name": row.name,
        "created_at": str(row.created_at),
        "updated_at": str(row.updated_at),
        "opened_at": opened_at,
        "user_id": row.user_id,
        "is_public": row.is_public,
    }
    return WorkflowRecordDTO.from_dict(payload)
|
||||
|
||||
|
||||
def _row_to_list_item(row: WorkflowLibraryTable) -> WorkflowRecordListItemDTO:
    """Convert a WorkflowLibraryTable ORM row into a WorkflowRecordListItemDTO."""
    # opened_at may be NULL for workflows that were never opened.
    opened_at = str(row.opened_at) if row.opened_at else None
    payload = {
        "workflow_id": row.workflow_id,
        "category": row.category,
        "name": row.name,
        "description": row.description,
        "created_at": str(row.created_at),
        "updated_at": str(row.updated_at),
        "opened_at": opened_at,
        "tags": row.tags,
        "user_id": row.user_id,
        "is_public": row.is_public,
    }
    return WorkflowRecordListItemDTOValidator.validate_python(payload)
|
||||
|
||||
|
||||
class SqlModelWorkflowRecordsStorage(WorkflowRecordsStorageBase):
    """SQLModel-backed storage for the workflow library (CRUD, listing, tag/category counts, default sync)."""

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def start(self, invoker: Invoker) -> None:
        """Called at app startup: store the invoker and sync bundled default workflows."""
        self._invoker = invoker
        self._sync_default_workflows()

    def get(self, workflow_id: str) -> WorkflowRecordDTO:
        """Fetch a workflow by id.

        Raises:
            WorkflowNotFoundError: if no row exists with the given id.
        """
        with self._db.get_readonly_session() as session:
            row = session.get(WorkflowLibraryTable, workflow_id)
            if row is None:
                raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
            return _row_to_dto(row)

    def create(
        self,
        workflow: WorkflowWithoutID,
        user_id: str = WORKFLOW_LIBRARY_DEFAULT_USER_ID,
        is_public: bool = False,
    ) -> WorkflowRecordDTO:
        """Create a non-default workflow with a generated id and return the stored record.

        Raises:
            ValueError: if the workflow's category is Default.
        """
        if workflow.meta.category is WorkflowCategory.Default:
            raise ValueError("Default workflows cannot be created via this method")

        workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
        row = WorkflowLibraryTable(
            workflow_id=workflow_with_id.id,
            workflow=workflow_with_id.model_dump_json(),
            user_id=user_id,
            is_public=is_public,
        )
        with self._db.get_session() as session:
            session.add(row)
        # Re-read after commit so DB-populated fields (timestamps) are included.
        return self.get(workflow_with_id.id)

    def update(self, workflow: Workflow, user_id: Optional[str] = None) -> WorkflowRecordDTO:
        """Overwrite the stored JSON of a user-category workflow; optionally scoped to a user.

        NOTE(review): if no matching row is found the update is silently skipped and
        the current record is returned unchanged — confirm that is intentional.
        """
        if workflow.meta.category is WorkflowCategory.Default:
            raise ValueError("Default workflows cannot be updated")

        with self._db.get_session() as session:
            stmt = select(WorkflowLibraryTable).where(
                col(WorkflowLibraryTable.workflow_id) == workflow.id,
                col(WorkflowLibraryTable.category) == "user",
            )
            if user_id is not None:
                stmt = stmt.where(col(WorkflowLibraryTable.user_id) == user_id)

            row = session.exec(stmt).first()
            if row is not None:
                row.workflow = workflow.model_dump_json()
                session.add(row)

        return self.get(workflow.id)

    def delete(self, workflow_id: str, user_id: Optional[str] = None) -> None:
        """Delete a user-category workflow; defaults are protected.

        Raises:
            ValueError: if the workflow is a default workflow.
            WorkflowNotFoundError: if the workflow does not exist at all (via get()).
        """
        if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
            raise ValueError("Default workflows cannot be deleted")

        with self._db.get_session() as session:
            stmt = select(WorkflowLibraryTable).where(
                col(WorkflowLibraryTable.workflow_id) == workflow_id,
                col(WorkflowLibraryTable.category) == "user",
            )
            if user_id is not None:
                stmt = stmt.where(col(WorkflowLibraryTable.user_id) == user_id)

            row = session.exec(stmt).first()
            if row is not None:
                session.delete(row)

    def update_is_public(self, workflow_id: str, is_public: bool, user_id: Optional[str] = None) -> WorkflowRecordDTO:
        """Toggle a workflow's public flag, keeping a synthetic "shared" tag in sync."""
        record = self.get(workflow_id)
        workflow = record.workflow

        # Add/remove the "shared" marker tag so list filters by tag stay consistent.
        tags_list = [t.strip() for t in workflow.tags.split(",") if t.strip()] if workflow.tags else []
        if is_public and "shared" not in tags_list:
            tags_list.append("shared")
        elif not is_public and "shared" in tags_list:
            tags_list.remove("shared")
        updated_tags = ", ".join(tags_list)
        updated_workflow = workflow.model_copy(update={"tags": updated_tags})

        with self._db.get_session() as session:
            stmt = select(WorkflowLibraryTable).where(
                col(WorkflowLibraryTable.workflow_id) == workflow_id,
                col(WorkflowLibraryTable.category) == "user",
            )
            if user_id is not None:
                stmt = stmt.where(col(WorkflowLibraryTable.user_id) == user_id)

            row = session.exec(stmt).first()
            if row is not None:
                row.workflow = updated_workflow.model_dump_json()
                row.is_public = is_public
                session.add(row)

        return self.get(workflow_id)

    def get_many(
        self,
        order_by: WorkflowRecordOrderBy,
        direction: SQLiteDirection,
        categories: Optional[list[WorkflowCategory]] = None,
        page: int = 0,
        per_page: Optional[int] = None,
        query: Optional[str] = None,
        tags: Optional[list[str]] = None,
        has_been_opened: Optional[bool] = None,
        user_id: Optional[str] = None,
        is_public: Optional[bool] = None,
    ) -> PaginatedResults[WorkflowRecordListItemDTO]:
        """Paginated, filtered, ordered listing of workflow library entries.

        Pagination is zero-based; when per_page is None/0 all matches are returned
        as a single page.
        """
        with self._db.get_readonly_session() as session:
            stmt = select(WorkflowLibraryTable)
            count_stmt = select(func.count()).select_from(WorkflowLibraryTable)

            # Apply filters to both
            stmt, count_stmt = self._apply_filters(
                stmt,
                count_stmt,
                categories,
                query,
                tags,
                has_been_opened,
                user_id,
                is_public,
            )

            # Count
            total = session.exec(count_stmt).one()

            # Ordering
            order_col = self._get_order_col(order_by)
            stmt = stmt.order_by(order_col.desc() if direction == SQLiteDirection.Descending else order_col.asc())

            # Pagination
            if per_page:
                stmt = stmt.limit(per_page).offset(page * per_page)

            rows = session.exec(stmt).all()
            workflows = [_row_to_list_item(r) for r in rows]

            # Ceiling division: bool (total % per_page > 0) adds the partial page.
            if per_page:
                pages = total // per_page + (total % per_page > 0)
            else:
                pages = 1

            return PaginatedResults(
                items=workflows,
                page=page,
                per_page=per_page if per_page else total,
                pages=pages,
                total=total,
            )

    def counts_by_tag(
        self,
        tags: list[str],
        categories: Optional[list[WorkflowCategory]] = None,
        has_been_opened: Optional[bool] = None,
        user_id: Optional[str] = None,
        is_public: Optional[bool] = None,
    ) -> dict[str, int]:
        """Return {tag: matching-workflow-count} under the given filters.

        Tag matching uses LIKE '%tag%', so a tag that is a substring of another
        tag will also match — presumably acceptable; verify against callers.
        """
        if not tags:
            return {}

        result: dict[str, int] = {}
        with self._db.get_readonly_session() as session:
            for tag in tags:
                stmt = select(func.count()).select_from(WorkflowLibraryTable)
                # stmt is passed as both args; both returned statements carry the
                # same filters, so discarding the second is safe.
                stmt, _ = self._apply_filters(stmt, stmt, categories, None, None, has_been_opened, user_id, is_public)
                stmt = stmt.where(col(WorkflowLibraryTable.tags).like(f"%{tag.strip()}%"))
                count = session.exec(stmt).one()
                result[tag] = count
        return result

    def counts_by_category(
        self,
        categories: list[WorkflowCategory],
        has_been_opened: Optional[bool] = None,
        user_id: Optional[str] = None,
        is_public: Optional[bool] = None,
    ) -> dict[str, int]:
        """Return {category value: matching-workflow-count} under the given filters."""
        result: dict[str, int] = {}
        with self._db.get_readonly_session() as session:
            for category in categories:
                stmt = select(func.count()).select_from(WorkflowLibraryTable)
                # NOTE(review): passing `categories` here adds an IN-filter over all
                # requested categories before the per-category equality below; the
                # IN-filter is redundant (equality is stricter) — confirm intended.
                stmt, _ = self._apply_filters(stmt, stmt, categories, None, None, has_been_opened, user_id, is_public)
                stmt = stmt.where(col(WorkflowLibraryTable.category) == category.value)
                count = session.exec(stmt).one()
                result[category.value] = count
        return result

    def update_opened_at(self, workflow_id: str, user_id: Optional[str] = None) -> None:
        """Stamp the workflow's opened_at with the current UTC time (no-op if not found)."""
        with self._db.get_session() as session:
            stmt = select(WorkflowLibraryTable).where(col(WorkflowLibraryTable.workflow_id) == workflow_id)
            if user_id is not None:
                stmt = stmt.where(col(WorkflowLibraryTable.user_id) == user_id)
            row = session.exec(stmt).first()
            if row is not None:
                # NOTE(review): datetime.utcnow() is naive and deprecated since
                # Python 3.12; the users service uses datetime.now(timezone.utc) —
                # consider aligning. Confirm the column expects naive timestamps.
                row.opened_at = datetime.utcnow()
                session.add(row)

    def get_all_tags(
        self,
        categories: Optional[list[WorkflowCategory]] = None,
        user_id: Optional[str] = None,
        is_public: Optional[bool] = None,
    ) -> list[str]:
        """Return the sorted set of distinct tags across matching workflows.

        Tags are stored as a comma-separated string per row and split/stripped here.
        """
        with self._db.get_readonly_session() as session:
            stmt = select(WorkflowLibraryTable.tags).where(
                col(WorkflowLibraryTable.tags).is_not(None),
                col(WorkflowLibraryTable.tags) != "",
            )

            if categories:
                category_strings = [c.value for c in categories]
                stmt = stmt.where(col(WorkflowLibraryTable.category).in_(category_strings))
            if user_id is not None:
                # A user sees their own workflows plus the shared defaults.
                stmt = stmt.where(
                    (col(WorkflowLibraryTable.user_id) == user_id) | (col(WorkflowLibraryTable.category) == "default")
                )
            if is_public is True:
                stmt = stmt.where(col(WorkflowLibraryTable.is_public) == True)  # noqa: E712
            elif is_public is False:
                stmt = stmt.where(col(WorkflowLibraryTable.is_public) == False)  # noqa: E712

            rows = session.exec(stmt).all()

            all_tags: set[str] = set()
            for tags_value in rows:
                if tags_value and isinstance(tags_value, str):
                    for tag in tags_value.split(","):
                        tag_stripped = tag.strip()
                        if tag_stripped:
                            all_tags.add(tag_stripped)
            return sorted(all_tags)

    def _sync_default_workflows(self) -> None:
        """Syncs default workflows to the database."""
        # NOTE(review): self.get()/self.get_many() below open their own read
        # sessions while this write session is open — confirm SQLite WAL +
        # busy_timeout make that safe.
        with self._db.get_session() as session:
            workflows_from_file: list[Workflow] = []
            workflows_to_update: list[Workflow] = []
            workflows_to_add: list[Workflow] = []
            workflows_dir = Path(__file__).parent / Path("default_workflows")
            workflow_paths = workflows_dir.glob("*.json")

            for path in workflow_paths:
                bytes_ = path.read_bytes()
                workflow_from_file = WorkflowValidator.validate_json(bytes_)

                assert workflow_from_file.id.startswith("default_"), (
                    f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
                )
                assert workflow_from_file.meta.category is WorkflowCategory.Default, (
                    f"Invalid default workflow category: {workflow_from_file.meta.category}"
                )

                workflows_from_file.append(workflow_from_file)

                # Classify each bundled workflow as new or changed vs the DB copy.
                try:
                    workflow_from_db = self.get(workflow_from_file.id).workflow
                    if workflow_from_file != workflow_from_db:
                        self._invoker.services.logger.debug(
                            f"Updating library workflow {workflow_from_file.name} ({workflow_from_file.id})"
                        )
                        workflows_to_update.append(workflow_from_file)
                except WorkflowNotFoundError:
                    self._invoker.services.logger.debug(
                        f"Adding missing default workflow {workflow_from_file.name} ({workflow_from_file.id})"
                    )
                    workflows_to_add.append(workflow_from_file)

            # Delete obsolete defaults
            library_workflows_from_db = self.get_many(
                order_by=WorkflowRecordOrderBy.Name,
                direction=SQLiteDirection.Ascending,
                categories=[WorkflowCategory.Default],
            ).items
            workflows_from_file_ids = [w.id for w in workflows_from_file]

            for w in library_workflows_from_db:
                if w.workflow_id not in workflows_from_file_ids:
                    self._invoker.services.logger.debug(
                        f"Deleting obsolete default workflow {w.name} ({w.workflow_id})"
                    )
                    row = session.get(WorkflowLibraryTable, w.workflow_id)
                    if row is not None:
                        session.delete(row)

            # Add new defaults
            for w in workflows_to_add:
                session.add(
                    WorkflowLibraryTable(
                        workflow_id=w.id,
                        workflow=w.model_dump_json(),
                    )
                )

            # Update changed defaults
            for w in workflows_to_update:
                row = session.get(WorkflowLibraryTable, w.id)
                if row is not None:
                    row.workflow = w.model_dump_json()
                    session.add(row)

    @staticmethod
    def _apply_filters(stmt, count_stmt, categories, query, tags, has_been_opened, user_id, is_public):
        """Apply common filters to both data and count queries."""
        if categories:
            category_strings = [c.value for c in categories]
            cond = col(WorkflowLibraryTable.category).in_(category_strings)
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)

        if tags:
            # Each tag is ANDed: a row must match every requested tag.
            for tag in tags:
                cond = col(WorkflowLibraryTable.tags).like(f"%{tag.strip()}%")
                stmt = stmt.where(cond)
                count_stmt = count_stmt.where(cond)

        if has_been_opened is True:
            cond = col(WorkflowLibraryTable.opened_at).is_not(None)
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)
        elif has_been_opened is False:
            cond = col(WorkflowLibraryTable.opened_at).is_(None)
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)

        # Free-text search across name, description and tags.
        stripped_query = query.strip() if query else None
        if stripped_query:
            wildcard = f"%{stripped_query}%"
            cond = (
                col(WorkflowLibraryTable.name).like(wildcard)
                | col(WorkflowLibraryTable.description).like(wildcard)
                | col(WorkflowLibraryTable.tags).like(wildcard)
            )
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)

        if user_id is not None:
            # A user sees their own workflows plus the shared defaults.
            cond = (col(WorkflowLibraryTable.user_id) == user_id) | (col(WorkflowLibraryTable.category) == "default")
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)

        if is_public is True:
            cond = col(WorkflowLibraryTable.is_public) == True  # noqa: E712
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)
        elif is_public is False:
            cond = col(WorkflowLibraryTable.is_public) == False  # noqa: E712
            stmt = stmt.where(cond)
            count_stmt = count_stmt.where(cond)

        return stmt, count_stmt

    @staticmethod
    def _get_order_col(order_by: WorkflowRecordOrderBy):
        """Map the order-by enum to its table column; falls back to created_at."""
        if order_by == WorkflowRecordOrderBy.Name:
            return col(WorkflowLibraryTable.name)
        elif order_by == WorkflowRecordOrderBy.Description:
            return col(WorkflowLibraryTable.description)
        elif order_by == WorkflowRecordOrderBy.CreatedAt:
            return col(WorkflowLibraryTable.created_at)
        elif order_by == WorkflowRecordOrderBy.UpdatedAt:
            return col(WorkflowLibraryTable.updated_at)
        elif order_by == WorkflowRecordOrderBy.OpenedAt:
            return col(WorkflowLibraryTable.opened_at)
        else:
            return col(WorkflowLibraryTable.created_at)
|
||||
@@ -61,6 +61,7 @@ dependencies = [
|
||||
"pydantic-settings",
|
||||
"pydantic",
|
||||
"python-socketio",
|
||||
"sqlmodel",
|
||||
"uvicorn[standard]",
|
||||
|
||||
# Auxiliary dependencies, pinned only if necessary.
|
||||
|
||||
22
tests/app/services/test_sqlmodel_services/conftest.py
Normal file
22
tests/app/services/test_sqlmodel_services/conftest.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Shared fixtures for SQLModel service tests."""
|
||||
|
||||
from logging import Logger
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from tests.fixtures.sqlite_database import create_mock_sqlite_database
|
||||
|
||||
|
||||
@pytest.fixture
def logger() -> Logger:
    """Provide the shared InvokeAI logger for SQLModel service tests."""
    return InvokeAILogger.get_logger()


@pytest.fixture
def db(logger: Logger) -> SqliteDatabase:
    """Create an in-memory database with all migrations applied."""
    config = InvokeAIAppConfig(use_memory_db=True)
    return create_mock_sqlite_database(config=config, logger=logger)
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Tests for AppSettingsServiceSqlModel."""
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.app_settings.app_settings_sqlmodel import AppSettingsServiceSqlModel
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
@pytest.fixture
def app_settings(db: SqliteDatabase) -> AppSettingsServiceSqlModel:
    """Provide an app-settings service backed by the in-memory test database."""
    return AppSettingsServiceSqlModel(db=db)


def test_get_nonexistent_key(app_settings: AppSettingsServiceSqlModel):
    """Unknown keys resolve to None rather than raising."""
    assert app_settings.get("nonexistent") is None


def test_set_and_get(app_settings: AppSettingsServiceSqlModel):
    """A value round-trips through set/get."""
    app_settings.set("test_key", "test_value")
    assert app_settings.get("test_key") == "test_value"


def test_set_overwrites_existing(app_settings: AppSettingsServiceSqlModel):
    """Setting the same key twice keeps only the latest value."""
    app_settings.set("key", "value1")
    app_settings.set("key", "value2")
    assert app_settings.get("key") == "value2"


def test_get_jwt_secret(app_settings: AppSettingsServiceSqlModel):
    """The JWT secret seeded by the migration is present and non-empty."""
    # jwt_secret is created by migration 27
    secret = app_settings.get_jwt_secret()
    assert secret is not None
    assert len(secret) > 0


def test_multiple_keys(app_settings: AppSettingsServiceSqlModel):
    """Distinct keys are stored independently."""
    app_settings.set("key1", "val1")
    app_settings.set("key2", "val2")
    assert app_settings.get("key1") == "val1"
    assert app_settings.get("key2") == "val2"
||||
@@ -0,0 +1,329 @@
|
||||
"""Benchmark: SQLModel vs raw SQLite implementations.
|
||||
|
||||
Compares performance of the old raw-SQL services against the new SQLModel services.
|
||||
Run with: pytest tests/app/services/test_sqlmodel_services/test_benchmark_sqlmodel_vs_sqlite.py -v -s
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import BoardChanges, BoardRecordOrderBy
|
||||
from invokeai.app.services.board_records.board_records_sqlite import SqliteBoardRecordStorage
|
||||
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ImageRecordChanges, ResourceOrigin
|
||||
from invokeai.app.services.image_records.image_records_sqlite import SqliteImageRecordStorage
|
||||
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_common import UserCreateRequest
|
||||
from invokeai.app.services.users.users_default import UserService
|
||||
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _time_it(func, iterations=1):
    """Invoke *func* `iterations` times and return total elapsed wall time in seconds."""
    begin = time.perf_counter()
    remaining = iterations
    while remaining > 0:
        func()
        remaining -= 1
    return time.perf_counter() - begin
|
||||
|
||||
|
||||
def _report(name: str, sqlite_time: float, sqlmodel_time: float, iterations: int):
    """Print a comparison of the two timings and return them unchanged.

    Args:
        name: Human-readable label for the benchmarked operation.
        sqlite_time: Total seconds taken by the raw-SQLite implementation.
        sqlmodel_time: Total seconds taken by the SQLModel implementation.
        iterations: Number of iterations the timings cover (display only).

    Returns:
        The ``(sqlite_time, sqlmodel_time)`` pair, so callers can aggregate.
    """
    ratio = sqlmodel_time / sqlite_time if sqlite_time > 0 else float("inf")
    if ratio >= 1:
        faster, factor = "SQLite", ratio
    elif ratio > 0:
        faster, factor = "SQLModel", 1 / ratio
    else:
        # sqlmodel_time == 0.0 (possible at perf_counter resolution): the old
        # code computed 1 / 0 here and raised ZeroDivisionError.
        faster, factor = "SQLModel", float("inf")
    print(f"\n  {name} ({iterations} iterations):")
    print(f"    SQLite:   {sqlite_time * 1000:.1f} ms")
    print(f"    SQLModel: {sqlmodel_time * 1000:.1f} ms")
    print(f"    -> {faster} is {factor:.2f}x faster")
    return sqlite_time, sqlmodel_time
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Board Records Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBoardRecordsBenchmark:
    """Compare board record operations between SQLite and SQLModel."""

    # Workload sizes; raise for steadier (but slower) measurements.
    N_BOARDS = 100
    N_READS = 200
    N_QUERIES = 50

    def test_insert_boards(self, db: SqliteDatabase):
        """Time bulk board creation through both backends."""
        raw_store = SqliteBoardRecordStorage(db=db)
        orm_store = SqlModelBoardRecordStorage(db=db)

        def raw_insert():
            for idx in range(self.N_BOARDS):
                raw_store.save(f"sqlite_board_{idx}", "user1")

        def orm_insert():
            for idx in range(self.N_BOARDS):
                orm_store.save(f"sqlmodel_board_{idx}", "user1")

        raw_elapsed = _time_it(raw_insert)
        orm_elapsed = _time_it(orm_insert)
        _report("INSERT boards", raw_elapsed, orm_elapsed, self.N_BOARDS)

    def test_get_boards(self, db: SqliteDatabase):
        """Time point lookups by board ID through both backends."""
        raw_store = SqliteBoardRecordStorage(db=db)
        orm_store = SqlModelBoardRecordStorage(db=db)

        # Seed once; both backends read the same underlying table.
        board_ids = [raw_store.save(f"board_{idx}", "user1").board_id for idx in range(self.N_BOARDS)]

        def raw_reads():
            for board_id in board_ids:
                raw_store.get(board_id)

        def orm_reads():
            for board_id in board_ids:
                orm_store.get(board_id)

        raw_elapsed = _time_it(raw_reads, iterations=3)
        orm_elapsed = _time_it(orm_reads, iterations=3)
        _report("GET boards (by ID)", raw_elapsed, orm_elapsed, self.N_BOARDS * 3)

    def test_get_many_boards(self, db: SqliteDatabase):
        """Time the paginated board listing query through both backends."""
        raw_store = SqliteBoardRecordStorage(db=db)
        orm_store = SqlModelBoardRecordStorage(db=db)

        for idx in range(self.N_BOARDS):
            raw_store.save(f"board_{idx}", "user1")

        # Identical arguments for both backends.
        query_kwargs = dict(
            user_id="user1",
            is_admin=False,
            order_by=BoardRecordOrderBy.CreatedAt,
            direction=SQLiteDirection.Descending,
            offset=0,
            limit=20,
        )

        raw_elapsed = _time_it(lambda: raw_store.get_many(**query_kwargs), iterations=self.N_QUERIES)
        orm_elapsed = _time_it(lambda: orm_store.get_many(**query_kwargs), iterations=self.N_QUERIES)
        _report("GET MANY boards (paginated)", raw_elapsed, orm_elapsed, self.N_QUERIES)

    def test_update_boards(self, db: SqliteDatabase):
        """Time board renames through both backends."""
        raw_store = SqliteBoardRecordStorage(db=db)
        orm_store = SqlModelBoardRecordStorage(db=db)

        board_ids = [raw_store.save(f"board_{idx}", "user1").board_id for idx in range(self.N_BOARDS)]

        def raw_updates():
            for board_id in board_ids:
                raw_store.update(board_id, BoardChanges(board_name="updated"))

        def orm_updates():
            for board_id in board_ids:
                orm_store.update(board_id, BoardChanges(board_name="updated"))

        raw_elapsed = _time_it(raw_updates)
        orm_elapsed = _time_it(orm_updates)
        _report("UPDATE boards", raw_elapsed, orm_elapsed, self.N_BOARDS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Image Records Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImageRecordsBenchmark:
    """Compare image record operations between SQLite and SQLModel."""

    N_IMAGES = 200
    N_QUERIES = 50

    def _save_images(self, storage, prefix: str, n: int):
        """Insert *n* 512x512 internal images named ``{prefix}_{i}`` via *storage*.

        Every 5th image is marked intermediate and every 10th starred, so the
        filtering/count benchmarks below have a mixed population to chew on.
        """
        for i in range(n):
            storage.save(
                image_name=f"{prefix}_{i}",
                image_origin=ResourceOrigin.INTERNAL,
                image_category=ImageCategory.GENERAL,
                width=512,
                height=512,
                has_workflow=False,
                is_intermediate=(i % 5 == 0),
                starred=(i % 10 == 0),
                user_id="user1",
            )

    def test_insert_images(self, db: SqliteDatabase):
        """Time bulk image inserts through both backends."""
        raw_store = SqliteImageRecordStorage(db=db)
        orm_store = SqlModelImageRecordStorage(db=db)

        raw_elapsed = _time_it(lambda: self._save_images(raw_store, "sqlite", self.N_IMAGES))
        orm_elapsed = _time_it(lambda: self._save_images(orm_store, "sqlmodel", self.N_IMAGES))
        _report("INSERT images", raw_elapsed, orm_elapsed, self.N_IMAGES)

    def test_get_images(self, db: SqliteDatabase):
        """Time point lookups by image name through both backends."""
        raw_store = SqliteImageRecordStorage(db=db)
        orm_store = SqlModelImageRecordStorage(db=db)

        self._save_images(raw_store, "img", self.N_IMAGES)
        names = [f"img_{i}" for i in range(self.N_IMAGES)]

        def raw_reads():
            for image_name in names:
                raw_store.get(image_name)

        def orm_reads():
            for image_name in names:
                orm_store.get(image_name)

        raw_elapsed = _time_it(raw_reads)
        orm_elapsed = _time_it(orm_reads)
        _report("GET images (by name)", raw_elapsed, orm_elapsed, self.N_IMAGES)

    def test_get_many_images(self, db: SqliteDatabase):
        """Time the paginated, category-filtered listing through both backends."""
        raw_store = SqliteImageRecordStorage(db=db)
        orm_store = SqlModelImageRecordStorage(db=db)

        self._save_images(raw_store, "img", self.N_IMAGES)

        def raw_query():
            raw_store.get_many(
                offset=0,
                limit=20,
                starred_first=True,
                order_dir=SQLiteDirection.Descending,
                categories=[ImageCategory.GENERAL],
            )

        def orm_query():
            orm_store.get_many(
                offset=0,
                limit=20,
                starred_first=True,
                order_dir=SQLiteDirection.Descending,
                categories=[ImageCategory.GENERAL],
            )

        raw_elapsed = _time_it(raw_query, iterations=self.N_QUERIES)
        orm_elapsed = _time_it(orm_query, iterations=self.N_QUERIES)
        _report("GET MANY images (paginated + filtered)", raw_elapsed, orm_elapsed, self.N_QUERIES)

    def test_get_intermediates_count(self, db: SqliteDatabase):
        """Time counting intermediate images through both backends."""
        raw_store = SqliteImageRecordStorage(db=db)
        orm_store = SqlModelImageRecordStorage(db=db)

        self._save_images(raw_store, "img", self.N_IMAGES)

        raw_elapsed = _time_it(lambda: raw_store.get_intermediates_count(), iterations=self.N_QUERIES)
        orm_elapsed = _time_it(lambda: orm_store.get_intermediates_count(), iterations=self.N_QUERIES)
        _report("COUNT intermediates", raw_elapsed, orm_elapsed, self.N_QUERIES)

    def test_update_images(self, db: SqliteDatabase):
        """Time toggling the starred flag through both backends."""
        raw_store = SqliteImageRecordStorage(db=db)
        orm_store = SqlModelImageRecordStorage(db=db)

        self._save_images(raw_store, "img", self.N_IMAGES)
        names = [f"img_{i}" for i in range(self.N_IMAGES)]

        def raw_updates():
            # Star via the raw backend...
            for image_name in names:
                raw_store.update(image_name, ImageRecordChanges(starred=True))

        def orm_updates():
            # ...then un-star via the ORM backend, so both perform real writes.
            for image_name in names:
                orm_store.update(image_name, ImageRecordChanges(starred=False))

        raw_elapsed = _time_it(raw_updates)
        orm_elapsed = _time_it(orm_updates)
        _report("UPDATE images (star)", raw_elapsed, orm_elapsed, self.N_IMAGES)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Users Benchmark
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUsersBenchmark:
    """Compare user operations between old and new implementations."""

    N_USERS = 50

    def test_create_users(self, db: SqliteDatabase):
        """Time user creation (including password hashing) through both services."""
        raw_service = UserService(db=db)
        orm_service = UserServiceSqlModel(db=db)

        def raw_creates():
            for idx in range(self.N_USERS):
                raw_service.create(
                    UserCreateRequest(
                        email=f"sqlite{idx}@test.com", display_name=f"SQLite {idx}", password="TestPassword123"
                    ),
                    strict_password_checking=False,
                )

        def orm_creates():
            for idx in range(self.N_USERS):
                orm_service.create(
                    UserCreateRequest(
                        email=f"sqlmodel{idx}@test.com", display_name=f"SQLModel {idx}", password="TestPassword123"
                    ),
                    strict_password_checking=False,
                )

        raw_elapsed = _time_it(raw_creates)
        orm_elapsed = _time_it(orm_creates)
        _report("CREATE users", raw_elapsed, orm_elapsed, self.N_USERS)

    def test_list_users(self, db: SqliteDatabase):
        """Time listing every user through both services."""
        raw_service = UserService(db=db)
        orm_service = UserServiceSqlModel(db=db)

        for idx in range(self.N_USERS):
            raw_service.create(
                UserCreateRequest(email=f"user{idx}@test.com", display_name=f"User {idx}", password="TestPassword123"),
                strict_password_checking=False,
            )

        raw_elapsed = _time_it(lambda: raw_service.list_users(), iterations=100)
        orm_elapsed = _time_it(lambda: orm_service.list_users(), iterations=100)
        _report("LIST users", raw_elapsed, orm_elapsed, 100)

    def test_authenticate_users(self, db: SqliteDatabase):
        """Time password authentication through both services."""
        raw_service = UserService(db=db)
        orm_service = UserServiceSqlModel(db=db)

        for idx in range(10):
            raw_service.create(
                UserCreateRequest(email=f"auth{idx}@test.com", display_name=f"Auth {idx}", password="TestPassword123"),
                strict_password_checking=False,
            )

        def raw_auths():
            for idx in range(10):
                raw_service.authenticate(f"auth{idx}@test.com", "TestPassword123")

        def orm_auths():
            for idx in range(10):
                orm_service.authenticate(f"auth{idx}@test.com", "TestPassword123")

        raw_elapsed = _time_it(raw_auths, iterations=10)
        orm_elapsed = _time_it(orm_auths, iterations=10)
        _report("AUTHENTICATE users", raw_elapsed, orm_elapsed, 100)
|
||||
@@ -0,0 +1,105 @@
|
||||
"""Tests for SqlModelBoardImageRecordStorage."""
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.board_image_records.board_image_records_sqlmodel import SqlModelBoardImageRecordStorage
|
||||
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
@pytest.fixture
def boards(db: SqliteDatabase) -> SqlModelBoardRecordStorage:
    """Board record storage bound to the per-test database."""
    return SqlModelBoardRecordStorage(db=db)


@pytest.fixture
def images(db: SqliteDatabase) -> SqlModelImageRecordStorage:
    """Image record storage bound to the per-test database."""
    return SqlModelImageRecordStorage(db=db)


@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelBoardImageRecordStorage:
    """Board-image association storage under test."""
    return SqlModelBoardImageRecordStorage(db=db)


def _create_image(images: SqlModelImageRecordStorage, name: str, category: ImageCategory = ImageCategory.GENERAL):
    """Persist a minimal 512x512 internal image record owned by user1."""
    fields = dict(
        image_name=name,
        image_origin=ResourceOrigin.INTERNAL,
        image_category=category,
        width=512,
        height=512,
        has_workflow=False,
        user_id="user1",
    )
    images.save(**fields)
|
||||
|
||||
|
||||
def test_add_image_to_board(storage, boards, images):
    """Adding an image to a board makes that board its owner."""
    board = boards.save("Board", "user1")
    _create_image(images, "img1")
    storage.add_image_to_board(board.board_id, "img1")
    assert storage.get_board_for_image("img1") == board.board_id


def test_remove_image_from_board(storage, boards, images):
    """Removing an image leaves it unassigned."""
    board = boards.save("Board", "user1")
    _create_image(images, "img1")
    storage.add_image_to_board(board.board_id, "img1")
    storage.remove_image_from_board("img1")
    assert storage.get_board_for_image("img1") is None


def test_move_image_between_boards(storage, boards, images):
    """Re-adding an image to another board moves it (last add wins)."""
    first = boards.save("Board 1", "user1")
    second = boards.save("Board 2", "user1")
    _create_image(images, "img1")
    storage.add_image_to_board(first.board_id, "img1")
    storage.add_image_to_board(second.board_id, "img1")
    assert storage.get_board_for_image("img1") == second.board_id


def test_get_board_for_unassigned_image(storage, images):
    """An image never added to a board reports no board."""
    _create_image(images, "img1")
    assert storage.get_board_for_image("img1") is None


def test_get_image_count_for_board(storage, boards, images):
    """Only IMAGE_CATEGORIES (GENERAL) rows count as images."""
    board = boards.save("Board", "user1")
    for name, category in (
        ("img1", ImageCategory.GENERAL),
        ("img2", ImageCategory.GENERAL),
        ("img3", ImageCategory.MASK),
    ):
        _create_image(images, name, category)
        storage.add_image_to_board(board.board_id, name)
    assert storage.get_image_count_for_board(board.board_id) == 2


def test_get_asset_count_for_board(storage, boards, images):
    """Only ASSETS_CATEGORIES (CONTROL, MASK, USER, OTHER) rows count as assets."""
    board = boards.save("Board", "user1")
    for name, category in (
        ("img1", ImageCategory.GENERAL),
        ("img2", ImageCategory.MASK),
        ("img3", ImageCategory.CONTROL),
    ):
        _create_image(images, name, category)
        storage.add_image_to_board(board.board_id, name)
    assert storage.get_asset_count_for_board(board.board_id) == 2


def test_get_all_board_image_names(storage, boards, images):
    """Every name attached to a board is returned when no filters are given."""
    board = boards.save("Board", "user1")
    for name in ("img1", "img2"):
        _create_image(images, name)
        storage.add_image_to_board(board.board_id, name)
    names = storage.get_all_board_image_names_for_board(board.board_id, categories=None, is_intermediate=None)
    assert set(names) == {"img1", "img2"}


def test_get_all_board_image_names_uncategorized(storage, images):
    """Images on no board appear under the pseudo-board id 'none'."""
    _create_image(images, "img1")
    names = storage.get_all_board_image_names_for_board("none", categories=None, is_intermediate=None)
    assert "img1" in names
|
||||
@@ -0,0 +1,161 @@
|
||||
"""Tests for SqlModelBoardRecordStorage."""
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.board_records.board_records_common import (
|
||||
BoardChanges,
|
||||
BoardRecordNotFoundException,
|
||||
BoardRecordOrderBy,
|
||||
BoardVisibility,
|
||||
)
|
||||
from invokeai.app.services.board_records.board_records_sqlmodel import SqlModelBoardRecordStorage
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelBoardRecordStorage:
    """Fresh SQLModel board record storage bound to the per-test database."""
    return SqlModelBoardRecordStorage(db=db)
|
||||
|
||||
|
||||
def test_save_and_get(storage: SqlModelBoardRecordStorage):
    """A saved board round-trips through get() with its name and owner."""
    created = storage.save("Test Board", "user1")
    assert created.board_name == "Test Board"
    assert created.user_id == "user1"

    fetched = storage.get(created.board_id)
    assert fetched.board_name == "Test Board"


def test_get_nonexistent(storage: SqlModelBoardRecordStorage):
    """Looking up an unknown board ID raises."""
    with pytest.raises(BoardRecordNotFoundException):
        storage.get("nonexistent")


def test_update_name(storage: SqlModelBoardRecordStorage):
    """update() applies a board_name change."""
    created = storage.save("Original", "user1")
    changed = storage.update(created.board_id, BoardChanges(board_name="Updated"))
    assert changed.board_name == "Updated"


def test_update_archived(storage: SqlModelBoardRecordStorage):
    """update() applies the archived flag."""
    created = storage.save("Board", "user1")
    changed = storage.update(created.board_id, BoardChanges(archived=True))
    assert changed.archived is True


def test_update_visibility(storage: SqlModelBoardRecordStorage):
    """update() applies a visibility change."""
    created = storage.save("Board", "user1")
    changed = storage.update(created.board_id, BoardChanges(board_visibility=BoardVisibility.Shared))
    assert changed.board_visibility == BoardVisibility.Shared


def test_delete(storage: SqlModelBoardRecordStorage):
    """A deleted board can no longer be fetched."""
    created = storage.save("Board", "user1")
    storage.delete(created.board_id)
    with pytest.raises(BoardRecordNotFoundException):
        storage.get(created.board_id)


def test_get_many_pagination(storage: SqlModelBoardRecordStorage):
    """get_many() honors the limit while reporting the full total."""
    for idx in range(5):
        storage.save(f"Board {idx}", "user1")

    page = storage.get_many(
        user_id="user1",
        is_admin=False,
        order_by=BoardRecordOrderBy.CreatedAt,
        direction=SQLiteDirection.Ascending,
        offset=0,
        limit=3,
    )
    assert len(page.items) == 3
    assert page.total == 5


def test_get_many_admin_sees_all(storage: SqlModelBoardRecordStorage):
    """Admins see boards owned by every user."""
    storage.save("User1 Board", "user1")
    storage.save("User2 Board", "user2")

    page = storage.get_many(
        user_id="admin",
        is_admin=True,
        order_by=BoardRecordOrderBy.CreatedAt,
        direction=SQLiteDirection.Ascending,
    )
    assert page.total == 2


def test_get_many_user_sees_own_and_shared(storage: SqlModelBoardRecordStorage):
    """Non-admins see their own boards plus shared ones, and nothing else."""
    storage.save("User1 Board", "user1")
    storage.save("User2 Board", "user2")
    shared = storage.save("Shared Board", "user1")
    storage.update(shared.board_id, BoardChanges(board_visibility=BoardVisibility.Shared))

    page = storage.get_many(
        user_id="user2",
        is_admin=False,
        order_by=BoardRecordOrderBy.CreatedAt,
        direction=SQLiteDirection.Ascending,
    )
    visible = {board.board_name for board in page.items}
    assert "User2 Board" in visible
    assert "Shared Board" in visible
    assert "User1 Board" not in visible


def test_get_many_exclude_archived(storage: SqlModelBoardRecordStorage):
    """include_archived=False filters archived boards out of the results."""
    storage.save("Active", "user1")
    archived = storage.save("Archived", "user1")
    storage.update(archived.board_id, BoardChanges(archived=True))

    page = storage.get_many(
        user_id="user1",
        is_admin=True,
        order_by=BoardRecordOrderBy.CreatedAt,
        direction=SQLiteDirection.Ascending,
        include_archived=False,
    )
    assert page.total == 1
    assert page.items[0].board_name == "Active"


def test_get_all(storage: SqlModelBoardRecordStorage):
    """get_all() returns every board without pagination."""
    for idx in range(3):
        storage.save(f"Board {idx}", "user1")

    everything = storage.get_all(
        user_id="user1",
        is_admin=True,
        order_by=BoardRecordOrderBy.CreatedAt,
        direction=SQLiteDirection.Ascending,
    )
    assert len(everything) == 3


def test_get_all_order_by_name(storage: SqlModelBoardRecordStorage):
    """Ordering by name sorts alphabetically, not by insertion order."""
    storage.save("Zebra", "user1")
    storage.save("Alpha", "user1")

    ordered = storage.get_all(
        user_id="user1",
        is_admin=True,
        order_by=BoardRecordOrderBy.Name,
        direction=SQLiteDirection.Ascending,
    )
    assert ordered[0].board_name == "Alpha"
    assert ordered[1].board_name == "Zebra"


def test_sql_injection_in_name(storage: SqlModelBoardRecordStorage):
    """A hostile board name is stored verbatim, not interpreted as SQL."""
    payload = "name'); DROP TABLE boards; --"
    created = storage.save(payload, "user1")
    assert storage.get(created.board_id).board_name == payload


def test_sql_injection_in_id(storage: SqlModelBoardRecordStorage):
    """A hostile ID matches nothing instead of bypassing the WHERE clause."""
    storage.save("board", "user1")
    with pytest.raises(BoardRecordNotFoundException):
        storage.get("fake' OR '1'='1")
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Tests for ClientStatePersistenceSqlModel."""
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.client_state_persistence.client_state_persistence_sqlmodel import (
|
||||
ClientStatePersistenceSqlModel,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_common import UserCreateRequest
|
||||
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
|
||||
|
||||
|
||||
@pytest.fixture
def users(db: SqliteDatabase) -> UserServiceSqlModel:
    """User service backed by the per-test database."""
    return UserServiceSqlModel(db=db)


@pytest.fixture
def user_id(users: UserServiceSqlModel) -> str:
    """ID of a freshly created test user."""
    created = users.create(
        UserCreateRequest(email="test@test.com", display_name="Test", password="TestPassword123"),
    )
    return created.user_id


@pytest.fixture
def client_state(db: SqliteDatabase) -> ClientStatePersistenceSqlModel:
    """Client-state persistence service under test."""
    return ClientStatePersistenceSqlModel(db=db)
|
||||
|
||||
|
||||
def test_get_nonexistent(client_state: ClientStatePersistenceSqlModel, user_id: str):
    """Unset keys read back as None."""
    assert client_state.get_by_key(user_id, "nonexistent") is None


def test_set_and_get(client_state: ClientStatePersistenceSqlModel, user_id: str):
    """A stored value round-trips for its key."""
    client_state.set_by_key(user_id, "theme", "dark")
    assert client_state.get_by_key(user_id, "theme") == "dark"


def test_set_overwrites(client_state: ClientStatePersistenceSqlModel, user_id: str):
    """Writing the same key twice keeps only the latest value."""
    for value in ("val1", "val2"):
        client_state.set_by_key(user_id, "key", value)
    assert client_state.get_by_key(user_id, "key") == "val2"


def test_delete_user_state(client_state: ClientStatePersistenceSqlModel, user_id: str):
    """delete() clears every key stored for the user."""
    client_state.set_by_key(user_id, "key1", "val1")
    client_state.set_by_key(user_id, "key2", "val2")
    client_state.delete(user_id)
    assert client_state.get_by_key(user_id, "key1") is None
    assert client_state.get_by_key(user_id, "key2") is None


def test_user_isolation(client_state: ClientStatePersistenceSqlModel, users: UserServiceSqlModel):
    """The same key holds independent values for different users."""
    first = users.create(UserCreateRequest(email="u1@test.com", display_name="U1", password="TestPassword123"))
    second = users.create(UserCreateRequest(email="u2@test.com", display_name="U2", password="TestPassword123"))
    client_state.set_by_key(first.user_id, "key", "user1_val")
    client_state.set_by_key(second.user_id, "key", "user2_val")
    assert client_state.get_by_key(first.user_id, "key") == "user1_val"
    assert client_state.get_by_key(second.user_id, "key") == "user2_val"
||||
@@ -0,0 +1,152 @@
|
||||
"""Tests for SqlModelImageRecordStorage."""
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageRecordChanges,
|
||||
ImageRecordNotFoundException,
|
||||
ResourceOrigin,
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_sqlmodel import SqlModelImageRecordStorage
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelImageRecordStorage:
    """Fresh SQLModel image record storage bound to the per-test database."""
    return SqlModelImageRecordStorage(db=db)
|
||||
|
||||
|
||||
def _save(storage, name="img1", category=ImageCategory.GENERAL, intermediate=False, starred=False, user_id="user1"):
    """Persist a minimal 512x512 internal image record; returns storage.save()'s result."""
    fields = dict(
        image_name=name,
        image_origin=ResourceOrigin.INTERNAL,
        image_category=category,
        width=512,
        height=512,
        has_workflow=False,
        is_intermediate=intermediate,
        starred=starred,
        user_id=user_id,
    )
    return storage.save(**fields)
|
||||
|
||||
|
||||
def test_save_and_get(storage):
    """A saved record round-trips with its core attributes."""
    _save(storage, "img1")
    fetched = storage.get("img1")
    assert fetched.image_name == "img1"
    assert fetched.width == 512
    assert fetched.image_category == ImageCategory.GENERAL


def test_get_nonexistent(storage):
    """Fetching an unknown image name raises."""
    with pytest.raises(ImageRecordNotFoundException):
        storage.get("nonexistent")


def test_save_returns_datetime(storage):
    """save() reports a non-None creation timestamp."""
    assert _save(storage, "img1") is not None


def test_update_category(storage):
    """update() applies a category change."""
    _save(storage, "img1")
    storage.update("img1", ImageRecordChanges(image_category=ImageCategory.MASK))
    assert storage.get("img1").image_category == ImageCategory.MASK


def test_update_starred(storage):
    """update() applies the starred flag."""
    _save(storage, "img1")
    storage.update("img1", ImageRecordChanges(starred=True))
    assert storage.get("img1").starred is True


def test_update_is_intermediate(storage):
    """update() applies the is_intermediate flag."""
    _save(storage, "img1")
    storage.update("img1", ImageRecordChanges(is_intermediate=True))
    assert storage.get("img1").is_intermediate is True


def test_delete(storage):
    """A deleted record can no longer be fetched."""
    _save(storage, "img1")
    storage.delete("img1")
    with pytest.raises(ImageRecordNotFoundException):
        storage.get("img1")


def test_delete_many(storage):
    """delete_many() removes only the named records."""
    for name in ("img1", "img2", "img3"):
        _save(storage, name)
    storage.delete_many(["img1", "img3"])
    with pytest.raises(ImageRecordNotFoundException):
        storage.get("img1")
    assert storage.get("img2").image_name == "img2"


def test_get_many_pagination(storage):
    """get_many() honors the limit while reporting the full total."""
    for idx in range(5):
        _save(storage, f"img{idx}")

    page = storage.get_many(offset=0, limit=3)
    assert len(page.items) == 3
    assert page.total == 5


def test_get_many_filter_by_category(storage):
    """Category filtering returns only matching records."""
    _save(storage, "img1", category=ImageCategory.GENERAL)
    _save(storage, "img2", category=ImageCategory.MASK)

    page = storage.get_many(categories=[ImageCategory.GENERAL])
    assert all(item.image_category == ImageCategory.GENERAL for item in page.items)


def test_get_many_starred_first(storage):
    """starred_first=True floats starred records to the top."""
    _save(storage, "img1", starred=False)
    _save(storage, "img2", starred=True)

    page = storage.get_many(starred_first=True, order_dir=SQLiteDirection.Descending)
    assert page.items[0].starred is True


def test_get_intermediates_count(storage):
    """Only is_intermediate records are counted."""
    _save(storage, "img1", intermediate=True)
    _save(storage, "img2", intermediate=True)
    _save(storage, "img3", intermediate=False)
    assert storage.get_intermediates_count() == 2


def test_get_intermediates_count_by_user(storage):
    """The intermediate count can be scoped to a single user."""
    _save(storage, "img1", intermediate=True, user_id="user1")
    _save(storage, "img2", intermediate=True, user_id="user2")
    assert storage.get_intermediates_count(user_id="user1") == 1


def test_delete_intermediates(storage):
    """delete_intermediates() removes exactly the intermediate records."""
    _save(storage, "img1", intermediate=True)
    _save(storage, "img2", intermediate=True)
    _save(storage, "img3", intermediate=False)
    removed = storage.delete_intermediates()
    assert set(removed) == {"img1", "img2"}
    assert storage.get("img3").image_name == "img3"


def test_get_user_id(storage):
    """get_user_id() returns the owner, or None for unknown images."""
    _save(storage, "img1", user_id="user42")
    assert storage.get_user_id("img1") == "user42"
    assert storage.get_user_id("nonexistent") is None


def test_get_image_names(storage):
    """get_image_names() reports totals and orders starred names first."""
    _save(storage, "img1")
    _save(storage, "img2", starred=True)

    listing = storage.get_image_names(starred_first=True)
    assert listing.total_count == 2
    assert listing.starred_count == 1
    # The single starred image leads the ordering.
    assert listing.image_names[0] == "img2"
|
||||
@@ -0,0 +1,139 @@
|
||||
"""Tests for ModelRecordServiceSqlModel."""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.model_records.model_records_base import (
|
||||
DuplicateModelException,
|
||||
ModelRecordChanges,
|
||||
ModelRecordOrderBy,
|
||||
UnknownModelException,
|
||||
)
|
||||
from invokeai.app.services.model_records.model_records_sqlmodel import ModelRecordServiceSqlModel
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.backend.model_manager.configs.main import Main_Diffusers_SD1_Config
|
||||
from invokeai.backend.model_manager.taxonomy import (
|
||||
BaseModelType,
|
||||
ModelFormat,
|
||||
ModelSourceType,
|
||||
ModelType,
|
||||
ModelVariantType,
|
||||
SchedulerPredictionType,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def storage(db: SqliteDatabase) -> ModelRecordServiceSqlModel:
    """SQLModel-backed model record service with a throwaway test logger."""
    return ModelRecordServiceSqlModel(db=db, logger=logging.getLogger("test"))
|
||||
|
||||
|
||||
def _make_config(
    key: str = "test-key", name: str = "Test Model", path: str = "/models/test"
) -> Main_Diffusers_SD1_Config:
    """Build a minimal SD1 diffusers main-model config for the tests below."""
    fields = dict(
        key=key,
        name=name,
        base=BaseModelType.StableDiffusion1,
        type=ModelType.Main,
        format=ModelFormat.Diffusers,
        path=path,
        hash="abc123",
        file_size=1024,
        source="/source",
        source_type=ModelSourceType.Path,
        prediction_type=SchedulerPredictionType.Epsilon,
        variant=ModelVariantType.Normal,
    )
    return Main_Diffusers_SD1_Config(**fields)
|
||||
|
||||
|
||||
def test_add_and_get(storage):
    """A stored model can be retrieved by its key with fields intact."""
    storage.add_model(_make_config())
    loaded = storage.get_model("test-key")
    assert loaded.name == "Test Model"
    assert loaded.key == "test-key"


def test_add_duplicate_raises(storage):
    """Adding the same key twice raises DuplicateModelException."""
    cfg = _make_config()
    storage.add_model(cfg)
    with pytest.raises(DuplicateModelException):
        storage.add_model(cfg)


def test_get_nonexistent(storage):
    """Fetching an unknown key raises UnknownModelException."""
    with pytest.raises(UnknownModelException):
        storage.get_model("nonexistent")


def test_del_model(storage):
    """A deleted model can no longer be fetched."""
    storage.add_model(_make_config())
    storage.del_model("test-key")
    with pytest.raises(UnknownModelException):
        storage.get_model("test-key")


def test_del_nonexistent(storage):
    """Deleting an unknown key raises UnknownModelException."""
    with pytest.raises(UnknownModelException):
        storage.del_model("nonexistent")


def test_update_model(storage):
    """update_model applies the requested field changes."""
    storage.add_model(_make_config())
    result = storage.update_model("test-key", ModelRecordChanges(name="Updated Name"))
    assert result.name == "Updated Name"


def test_exists(storage):
    """exists flips from False to True once the model is stored."""
    assert storage.exists("test-key") is False
    storage.add_model(_make_config())
    assert storage.exists("test-key") is True
|
||||
|
||||
|
||||
def test_search_by_attr_name(storage):
    """Searching by model_name returns only the matching record."""
    storage.add_model(_make_config("k1", "Alpha", "/models/alpha"))
    storage.add_model(_make_config("k2", "Beta", "/models/beta"))
    matches = storage.search_by_attr(model_name="Alpha")
    assert len(matches) == 1
    assert matches[0].name == "Alpha"


def test_search_by_attr_all(storage):
    """Searching with no filters returns every stored record."""
    storage.add_model(_make_config("k1", "M1", "/m1"))
    storage.add_model(_make_config("k2", "M2", "/m2"))
    assert len(storage.search_by_attr()) == 2


def test_search_by_path(storage):
    """search_by_path matches the stored path exactly."""
    storage.add_model(_make_config("k1", "M1", "/models/specific"))
    matches = storage.search_by_path("/models/specific")
    assert len(matches) == 1
    assert matches[0].key == "k1"


def test_replace_model(storage):
    """replace_model swaps in the new config under the same key."""
    storage.add_model(_make_config())
    swapped = storage.replace_model("test-key", _make_config(name="Replaced"))
    assert swapped.name == "Replaced"


@pytest.mark.skip(reason="ModelSummary format needs investigation — list_models query works but DTO mapping differs")
def test_list_models(storage):
    """list_models paginates over all stored records."""
    storage.add_model(_make_config("k1", "M1", "/m1"))
    storage.add_model(_make_config("k2", "M2", "/m2"))
    page = storage.list_models(page=0, per_page=10)
    assert page.total == 2
    assert len(page.items) == 2


def test_search_by_attr_with_order(storage):
    """order_by=Name returns results sorted alphabetically by name."""
    storage.add_model(_make_config("k1", "Beta", "/m1"))
    storage.add_model(_make_config("k2", "Alpha", "/m2"))
    ordered = storage.search_by_attr(order_by=ModelRecordOrderBy.Name)
    assert ordered[0].name == "Alpha"
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Tests for SqlModelModelRelationshipRecordStorage."""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.model_relationship_records.model_relationship_records_sqlmodel import (
|
||||
SqlModelModelRelationshipRecordStorage,
|
||||
)
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
|
||||
def _add_model(db: SqliteDatabase, key: str, name: str = "test") -> None:
    """Helper to insert a model record for FK constraints using raw SQL (avoids generated column issues)."""
    record = {
        "key": key,
        "name": name,
        "base": "sd-1",
        "type": "main",
        "format": "diffusers",
        "path": f"/models/{key}",
        "hash": "abc123",
        "source": "/src",
        "source_type": "path",
        "file_size": 1024,
    }
    with db.transaction() as cursor:
        cursor.execute("INSERT INTO models (id, config) VALUES (?, ?)", (key, json.dumps(record)))
|
||||
|
||||
|
||||
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelModelRelationshipRecordStorage:
    """Relationship record storage bound to the shared test database."""
    store = SqlModelModelRelationshipRecordStorage(db=db)
    return store
|
||||
|
||||
|
||||
@pytest.fixture
def models(db: SqliteDatabase) -> tuple[str, str, str]:
    """Insert three model rows (satisfying FK constraints) and return their keys."""
    keys = ("model_a", "model_b", "model_c")
    for model_key in keys:
        _add_model(db, model_key, name=model_key)
    return keys
|
||||
|
||||
|
||||
def test_add_and_get_relationship(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
    """A stored relationship appears in the related-key listing."""
    a, b, _ = models
    storage.add_model_relationship(a, b)
    assert b in storage.get_related_model_keys(a)


def test_bidirectional(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
    """Relationships are symmetric: each side sees the other."""
    first, second, _ = models
    storage.add_model_relationship(first, second)
    assert first in storage.get_related_model_keys(second)
    assert second in storage.get_related_model_keys(first)


def test_self_relationship_raises(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
    """Relating a model to itself is rejected with a ValueError."""
    key, _, _ = models
    with pytest.raises(ValueError, match="Cannot relate a model to itself"):
        storage.add_model_relationship(key, key)


def test_remove_relationship(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
    """A removed relationship no longer appears in the listing."""
    a, b, _ = models
    storage.add_model_relationship(a, b)
    storage.remove_model_relationship(a, b)
    assert b not in storage.get_related_model_keys(a)


def test_duplicate_add_is_idempotent(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
    """Re-adding the same relationship neither raises nor duplicates the row."""
    a, b, _ = models
    storage.add_model_relationship(a, b)
    storage.add_model_relationship(a, b)  # should not raise
    assert storage.get_related_model_keys(a).count(b) == 1


def test_get_related_batch(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
    """Batch lookup returns all keys related to any of the given keys."""
    a, b, c = models
    storage.add_model_relationship(a, b)
    storage.add_model_relationship(a, c)
    assert set(storage.get_related_model_keys_batch([a])) == {b, c}


def test_no_relationships(storage: SqlModelModelRelationshipRecordStorage, models: tuple):
    """A model with no relationships yields an empty list."""
    key, _, _ = models
    assert storage.get_related_model_keys(key) == []
|
||||
@@ -0,0 +1,467 @@
|
||||
"""Tests for the SQLModel-backed session queue implementation."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import insert
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
SessionQueueItemNotFoundError,
|
||||
)
|
||||
from invokeai.app.services.session_queue.session_queue_sqlmodel import SqlModelSessionQueue
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.shared.sqlite.models import SessionQueueTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from tests.test_nodes import PromptTestInvocation
|
||||
|
||||
|
||||
# ---- fixtures ----
|
||||
|
||||
|
||||
@pytest.fixture
def session_queue(mock_invoker: Invoker) -> SqlModelSessionQueue:
    """Create a SqlModelSessionQueue backed by the mock invoker's in-memory database."""
    database = mock_invoker.services.board_records._db
    sq = SqlModelSessionQueue(db=database)
    sq.start(mock_invoker)
    return sq
|
||||
|
||||
|
||||
@pytest.fixture
def batch_graph() -> Graph:
    """A one-node graph holding a single prompt invocation."""
    graph = Graph()
    graph.add_node(PromptTestInvocation(id="1", prompt="Chevy"))
    return graph
|
||||
|
||||
|
||||
# ---- helpers ----
|
||||
|
||||
|
||||
def _make_session_json() -> tuple[str, str]:
    """Build a valid GraphExecutionState JSON blob and return (session_id, json)."""
    graph = Graph()
    graph.add_node(PromptTestInvocation(id="1", prompt="Chevy"))
    exec_state = GraphExecutionState(graph=graph)
    serialized = exec_state.model_dump_json(warnings=False, exclude_none=True)
    return exec_state.id, serialized
|
||||
|
||||
|
||||
def _insert_raw(
    queue: SqlModelSessionQueue,
    *,
    queue_id: str = "default",
    user_id: str = "system",
    status: str = "pending",
    priority: int = 0,
    batch_id: Optional[str] = None,
    destination: Optional[str] = None,
) -> int:
    """Insert a minimal queue item via Core and return its item_id."""
    sid, blob = _make_session_json()
    # A fresh batch id per item unless the caller pins one explicitly.
    stmt = insert(SessionQueueTable).values(
        queue_id=queue_id,
        session=blob,
        session_id=sid,
        batch_id=batch_id or str(uuid.uuid4()),
        field_values=None,
        priority=priority,
        workflow=None,
        origin=None,
        destination=destination,
        retried_from_item_id=None,
        user_id=user_id,
        status=status,
    )
    with queue._db.get_session() as db_session:
        outcome = db_session.execute(stmt)
        return int(outcome.inserted_primary_key[0])
|
||||
|
||||
|
||||
# ---- start() / _set_in_progress_to_canceled ----
|
||||
|
||||
|
||||
def test_start_cancels_in_progress(mock_invoker: Invoker) -> None:
    """start() flips any leftover in_progress items to canceled."""
    database = mock_invoker.services.board_records._db
    queue = SqlModelSessionQueue(db=database)
    stale_id = _insert_raw(queue, status="in_progress")
    queue.start(mock_invoker)
    assert queue.get_queue_item(stale_id).status == "canceled"


# ---- simple read methods ----


def test_is_empty_and_is_full(session_queue: SqlModelSessionQueue) -> None:
    """is_empty reflects contents; a single item never fills the default queue."""
    assert session_queue.is_empty("default").is_empty is True
    _insert_raw(session_queue)
    assert session_queue.is_empty("default").is_empty is False
    # default max_queue_size is high; queue with 1 item is not full
    assert session_queue.is_full("default").is_full is False


def test_get_queue_item_not_found(session_queue: SqlModelSessionQueue) -> None:
    """Unknown item ids raise SessionQueueItemNotFoundError."""
    with pytest.raises(SessionQueueItemNotFoundError):
        session_queue.get_queue_item(99999)


def test_get_queue_item(session_queue: SqlModelSessionQueue) -> None:
    """A freshly inserted item round-trips with its fields intact."""
    new_id = _insert_raw(session_queue, user_id="alice")
    fetched = session_queue.get_queue_item(new_id)
    assert fetched.item_id == new_id
    assert fetched.user_id == "alice"
    assert fetched.status == "pending"


def test_get_current_and_get_next(session_queue: SqlModelSessionQueue) -> None:
    """get_current returns the in_progress item; get_next the best pending one."""
    pending_id = _insert_raw(session_queue, priority=1)
    running_id = _insert_raw(session_queue, status="in_progress")
    current = session_queue.get_current("default")
    assert current is not None and current.item_id == running_id
    upcoming = session_queue.get_next("default")
    assert upcoming is not None and upcoming.item_id == pending_id


def test_get_current_queue_size(session_queue: SqlModelSessionQueue) -> None:
    """Queue size counts pending items only, not completed ones."""
    for _ in range(2):
        _insert_raw(session_queue)
    _insert_raw(session_queue, status="completed")
    assert session_queue._get_current_queue_size("default") == 2


def test_get_highest_priority(session_queue: SqlModelSessionQueue) -> None:
    """Highest priority considers only pending items; empty queue yields 0."""
    assert session_queue._get_highest_priority("default") == 0
    _insert_raw(session_queue, priority=3)
    _insert_raw(session_queue, priority=7)
    _insert_raw(session_queue, priority=10, status="completed")  # ignored
    assert session_queue._get_highest_priority("default") == 7
|
||||
|
||||
|
||||
# ---- enqueue / dequeue ----
|
||||
|
||||
|
||||
def test_enqueue_batch_and_dequeue(
    session_queue: SqlModelSessionQueue, batch_graph: Graph
) -> None:
    """Enqueuing a two-run batch yields two items; dequeue marks one in_progress."""
    batch = Batch(graph=batch_graph, runs=2)
    enq = asyncio.run(session_queue.enqueue_batch("default", batch, prepend=False))
    assert enq.enqueued == 2
    assert enq.requested == 2
    assert len(enq.item_ids) == 2

    # dequeue takes the first pending and marks it in_progress
    taken = session_queue.dequeue()
    assert taken is not None
    assert taken.status == "in_progress"

    # only one in-progress at a time
    active = session_queue.get_current("default")
    assert active is not None and active.item_id == taken.item_id


def test_enqueue_batch_prepend_increases_priority(
    session_queue: SqlModelSessionQueue, batch_graph: Graph
) -> None:
    """prepend=True bumps the new batch above existing pending work."""
    asyncio.run(session_queue.enqueue_batch("default", Batch(graph=batch_graph), prepend=False))
    prepended = asyncio.run(
        session_queue.enqueue_batch("default", Batch(graph=batch_graph), prepend=True)
    )
    assert prepended.priority == 1


def test_dequeue_empty_returns_none(session_queue: SqlModelSessionQueue) -> None:
    """Dequeuing from an empty queue returns None rather than raising."""
    assert session_queue.dequeue() is None
|
||||
|
||||
|
||||
# ---- status mutations ----
|
||||
|
||||
|
||||
def test_complete_fail_cancel_queue_item(session_queue: SqlModelSessionQueue) -> None:
    """complete/fail/cancel set the expected statuses; terminal items are immutable."""
    first = _insert_raw(session_queue)
    assert session_queue.complete_queue_item(first).status == "completed"
    # second mutation on terminal-status item is a no-op (returns existing)
    assert session_queue.cancel_queue_item(first).status == "completed"

    second = _insert_raw(session_queue)
    failed = session_queue.fail_queue_item(second, "ErrType", "ErrMsg", "trace")
    assert failed.status == "failed"
    assert failed.error_type == "ErrType"
    assert failed.error_message == "ErrMsg"
    assert failed.error_traceback == "trace"

    third = _insert_raw(session_queue)
    assert session_queue.cancel_queue_item(third).status == "canceled"


def test_set_queue_item_status_unknown_id_raises(
    session_queue: SqlModelSessionQueue,
) -> None:
    """Setting status on a missing item id raises SessionQueueItemNotFoundError."""
    with pytest.raises(SessionQueueItemNotFoundError):
        session_queue._set_queue_item_status(99999, "completed")


def test_delete_queue_item(session_queue: SqlModelSessionQueue) -> None:
    """A deleted item can no longer be fetched."""
    target = _insert_raw(session_queue)
    session_queue.delete_queue_item(target)
    with pytest.raises(SessionQueueItemNotFoundError):
        session_queue.get_queue_item(target)


def test_set_queue_item_session(
    session_queue: SqlModelSessionQueue, batch_graph: Graph
) -> None:
    """set_queue_item_session replaces the stored session blob."""
    target = _insert_raw(session_queue)
    replacement = GraphExecutionState(graph=batch_graph)
    session_queue.set_queue_item_session(target, replacement)
    assert session_queue.get_queue_item(target).session.id == replacement.id
|
||||
|
||||
|
||||
# ---- bulk delete ----
|
||||
|
||||
|
||||
def test_clear_with_user_id_only_deletes_own_items(
    session_queue: SqlModelSessionQueue,
) -> None:
    """clear scoped to a user leaves other users' items untouched."""
    for _ in range(2):
        _insert_raw(session_queue, user_id="user_a")
    _insert_raw(session_queue, user_id="user_b")
    assert session_queue.clear("default", user_id="user_a").deleted == 2


def test_clear_without_user_id_deletes_all(session_queue: SqlModelSessionQueue) -> None:
    """clear without a user filter removes every item in the queue."""
    _insert_raw(session_queue, user_id="user_a")
    _insert_raw(session_queue, user_id="user_b")
    assert session_queue.clear("default").deleted == 2


def test_prune_only_deletes_terminal(session_queue: SqlModelSessionQueue) -> None:
    """prune removes completed/failed/canceled but keeps pending/in_progress."""
    for state in ("pending", "completed", "failed", "canceled", "in_progress"):
        _insert_raw(session_queue, status=state)
    assert session_queue.prune("default").deleted == 3
    # pending and in_progress remain
    remaining = session_queue.get_queue_status("default")
    assert remaining.pending == 1
    assert remaining.in_progress == 1


def test_prune_with_user_id(session_queue: SqlModelSessionQueue) -> None:
    """prune scoped to a user only removes that user's terminal items."""
    _insert_raw(session_queue, status="completed", user_id="user_a")
    _insert_raw(session_queue, status="failed", user_id="user_b")
    assert session_queue.prune("default", user_id="user_a").deleted == 1


def test_delete_by_destination(session_queue: SqlModelSessionQueue) -> None:
    """delete_by_destination removes only items for the given destination."""
    for _ in range(2):
        _insert_raw(session_queue, destination="canvas")
    _insert_raw(session_queue, destination="generate")
    assert session_queue.delete_by_destination("default", destination="canvas").deleted == 2


def test_delete_all_except_current(session_queue: SqlModelSessionQueue) -> None:
    """delete_all_except_current removes pending items only."""
    for _ in range(2):
        _insert_raw(session_queue, status="pending")
    _insert_raw(session_queue, status="in_progress")
    _insert_raw(session_queue, status="completed")
    # only deletes pending
    assert session_queue.delete_all_except_current("default").deleted == 2
    after = session_queue.get_queue_status("default")
    assert after.pending == 0
    assert after.in_progress == 1
    assert after.completed == 1
|
||||
|
||||
|
||||
# ---- bulk cancel ----
|
||||
|
||||
|
||||
def test_cancel_by_batch_ids(session_queue: SqlModelSessionQueue) -> None:
    """Only items belonging to the given batch ids are canceled."""
    shared_batch = str(uuid.uuid4())
    _insert_raw(session_queue, batch_id=shared_batch)
    _insert_raw(session_queue, batch_id=shared_batch)
    _insert_raw(session_queue, batch_id=str(uuid.uuid4()))  # different batch
    assert session_queue.cancel_by_batch_ids("default", [shared_batch]).canceled == 2


def test_cancel_by_destination(session_queue: SqlModelSessionQueue) -> None:
    """Terminal items and other destinations are left alone."""
    _insert_raw(session_queue, destination="canvas")
    _insert_raw(session_queue, destination="canvas", status="completed")  # skipped
    _insert_raw(session_queue, destination="generate")  # different dest
    assert session_queue.cancel_by_destination("default", "canvas").canceled == 1


def test_cancel_by_queue_id(session_queue: SqlModelSessionQueue) -> None:
    """Only items in the named queue are canceled."""
    _insert_raw(session_queue, queue_id="default")
    _insert_raw(session_queue, queue_id="default")
    _insert_raw(session_queue, queue_id="other")
    assert session_queue.cancel_by_queue_id("default").canceled == 2


def test_cancel_all_except_current(session_queue: SqlModelSessionQueue) -> None:
    """Pending items are canceled; the in_progress item survives."""
    _insert_raw(session_queue, status="pending")
    _insert_raw(session_queue, status="pending")
    _insert_raw(session_queue, status="in_progress")
    assert session_queue.cancel_all_except_current("default").canceled == 2
|
||||
|
||||
|
||||
# ---- prune-to-limit ----
|
||||
|
||||
|
||||
def test_prune_terminal_to_limit_keeps_n_most_recent(
    session_queue: SqlModelSessionQueue,
) -> None:
    """_prune_terminal_to_limit trims terminal items down to `keep` entries."""
    for _ in range(5):
        _insert_raw(session_queue, status="completed")
    removed = session_queue._prune_terminal_to_limit("default", keep=2)
    assert removed == 3
    assert session_queue.get_queue_status("default").completed == 2
|
||||
|
||||
|
||||
# ---- list / pagination ----
|
||||
|
||||
|
||||
def test_list_queue_items_pagination(session_queue: SqlModelSessionQueue) -> None:
    """Cursor-based pagination yields disjoint pages of queue items."""
    inserted = [_insert_raw(session_queue) for _ in range(5)]
    first_page = session_queue.list_queue_items("default", limit=2, priority=0)
    assert len(first_page.items) == 2
    assert first_page.has_more is True

    second_page = session_queue.list_queue_items(
        "default", limit=2, priority=0, cursor=first_page.items[-1].item_id
    )
    assert len(second_page.items) == 2

    # Make sure no item appears twice
    observed = {i.item_id for i in first_page.items} | {i.item_id for i in second_page.items}
    assert observed.issubset(set(inserted))
    assert len(observed) == 4


def test_list_queue_items_filters_status_and_destination(
    session_queue: SqlModelSessionQueue,
) -> None:
    """Status and destination filters combine with AND semantics."""
    _insert_raw(session_queue, destination="canvas", status="completed")
    _insert_raw(session_queue, destination="canvas", status="pending")
    _insert_raw(session_queue, destination="generate", status="completed")
    filtered = session_queue.list_queue_items(
        "default", limit=10, priority=0, status="completed", destination="canvas"
    )
    assert len(filtered.items) == 1


def test_list_all_queue_items(session_queue: SqlModelSessionQueue) -> None:
    """list_all_queue_items honors the destination filter."""
    _insert_raw(session_queue, destination="canvas")
    _insert_raw(session_queue, destination="canvas")
    _insert_raw(session_queue, destination="generate")
    assert len(session_queue.list_all_queue_items("default", destination="canvas")) == 2
|
||||
|
||||
|
||||
def test_get_queue_item_ids_ordering(session_queue: SqlModelSessionQueue) -> None:
    # Items inserted in the same millisecond may tie on created_at, so we only assert
    # set-equality and total_count. Ordering correctness is exercised by the SQL query
    # construction itself (covered by the production query path).
    expected = [_insert_raw(session_queue) for _ in range(3)]
    descending = session_queue.get_queue_item_ids("default", order_dir=SQLiteDirection.Descending)
    ascending = session_queue.get_queue_item_ids("default", order_dir=SQLiteDirection.Ascending)
    assert descending.total_count == 3
    assert ascending.total_count == 3
    assert set(descending.item_ids) == set(expected)
    assert set(ascending.item_ids) == set(expected)


def test_get_queue_item_ids_filters_user_id(session_queue: SqlModelSessionQueue) -> None:
    """The user_id filter restricts the id listing to that user's items."""
    _insert_raw(session_queue, user_id="alice")
    _insert_raw(session_queue, user_id="bob")
    assert session_queue.get_queue_item_ids("default", user_id="alice").total_count == 1
|
||||
|
||||
|
||||
# ---- aggregations ----
|
||||
|
||||
|
||||
def test_get_queue_status_counts(session_queue: SqlModelSessionQueue) -> None:
    """Queue status reports one item per inserted state plus the total."""
    for state in ("pending", "completed", "failed", "canceled"):
        _insert_raw(session_queue, status=state)
    counts = session_queue.get_queue_status("default")
    assert counts.pending == 1
    assert counts.completed == 1
    assert counts.failed == 1
    assert counts.canceled == 1
    assert counts.total == 4


def test_get_queue_status_user_id_hides_other_user_current(
    session_queue: SqlModelSessionQueue,
) -> None:
    """The current item is not exposed to a different user."""
    _insert_raw(session_queue, user_id="alice", status="in_progress")
    visible = session_queue.get_queue_status("default", user_id="bob")
    # current item exists but belongs to alice — should be hidden for bob
    assert visible.item_id is None


def test_get_batch_status(session_queue: SqlModelSessionQueue) -> None:
    """Batch status counts only items in the requested batch."""
    target_batch = str(uuid.uuid4())
    _insert_raw(session_queue, batch_id=target_batch, status="pending")
    _insert_raw(session_queue, batch_id=target_batch, status="completed")
    _insert_raw(session_queue, batch_id=str(uuid.uuid4()), status="completed")
    batch_counts = session_queue.get_batch_status("default", batch_id=target_batch)
    assert batch_counts.pending == 1
    assert batch_counts.completed == 1
    assert batch_counts.total == 2


def test_get_counts_by_destination(session_queue: SqlModelSessionQueue) -> None:
    """Destination counts cover only items with the requested destination."""
    _insert_raw(session_queue, destination="canvas", status="pending")
    _insert_raw(session_queue, destination="canvas", status="completed")
    _insert_raw(session_queue, destination="generate", status="pending")
    dest_counts = session_queue.get_counts_by_destination("default", destination="canvas")
    assert dest_counts.pending == 1
    assert dest_counts.completed == 1
    assert dest_counts.total == 2
|
||||
|
||||
|
||||
# ---- retry ----
|
||||
|
||||
|
||||
def test_retry_items_by_id_skips_non_terminal(
    session_queue: SqlModelSessionQueue,
) -> None:
    """Retrying a still-pending item is a no-op: only terminal items are cloned.

    NOTE: the unused `batch_graph` fixture parameter was removed — it forced
    needless fixture setup without being referenced in the test body.
    """
    pending_id = _insert_raw(session_queue, status="pending")
    result = session_queue.retry_items_by_id("default", [pending_id])
    assert result.retried_item_ids == []


def test_retry_items_by_id_clones_failed(
    session_queue: SqlModelSessionQueue, batch_graph: Graph
) -> None:
    """Retrying a failed item enqueues exactly one new pending clone."""
    # Use enqueue_batch so we get a valid `session` JSON, then fail it
    batch = Batch(graph=batch_graph, runs=1)
    enq = asyncio.run(session_queue.enqueue_batch("default", batch, prepend=False))
    item_id = enq.item_ids[0]
    session_queue.fail_queue_item(item_id, "ErrType", "ErrMsg", "trace")

    retry = session_queue.retry_items_by_id("default", [item_id])
    assert retry.retried_item_ids == [item_id]
    # exactly one new pending item should now exist (the original is failed)
    status = session_queue.get_queue_status("default")
    assert status.pending == 1
    assert status.failed == 1
|
||||
@@ -0,0 +1,89 @@
|
||||
"""Tests for SqlModelStylePresetRecordsStorage."""
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_common import (
|
||||
PresetData,
|
||||
StylePresetChanges,
|
||||
StylePresetNotFoundError,
|
||||
StylePresetWithoutId,
|
||||
)
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlmodel import SqlModelStylePresetRecordsStorage
|
||||
|
||||
|
||||
@pytest.fixture
def storage(db: SqliteDatabase) -> SqlModelStylePresetRecordsStorage:
    """Style preset record storage bound to the shared test database."""
    store = SqlModelStylePresetRecordsStorage(db=db)
    return store
|
||||
|
||||
|
||||
def _make_preset(name: str = "Test Preset", preset_type: str = "user") -> StylePresetWithoutId:
    """Build an id-less style preset with fixed prompt data."""
    data = PresetData(positive_prompt="a cat", negative_prompt="")
    return StylePresetWithoutId(name=name, preset_data=data, type=preset_type)
|
||||
|
||||
|
||||
def test_create_and_get(storage):
    """A created preset can be fetched back by its generated id."""
    created = storage.create(_make_preset("My Preset"))
    assert created.name == "My Preset"

    loaded = storage.get(created.id)
    assert loaded.name == "My Preset"


def test_get_nonexistent(storage):
    """Fetching an unknown preset id raises StylePresetNotFoundError."""
    with pytest.raises(StylePresetNotFoundError):
        storage.get("nonexistent")


def test_update_name(storage):
    """update applies a name change."""
    created = storage.create(_make_preset("Original"))
    changed = storage.update(created.id, StylePresetChanges(name="Updated", type=None))
    assert changed.name == "Updated"


def test_update_preset_data(storage):
    """update replaces the preset's prompt data."""
    created = storage.create(_make_preset())
    replacement = PresetData(positive_prompt="a dog", negative_prompt="ugly")
    changed = storage.update(created.id, StylePresetChanges(preset_data=replacement, type=None))
    assert changed.preset_data.positive_prompt == "a dog"


def test_delete(storage):
    """A deleted preset can no longer be fetched."""
    created = storage.create(_make_preset())
    storage.delete(created.id)
    with pytest.raises(StylePresetNotFoundError):
        storage.get(created.id)


def test_get_many(storage):
    """get_many returns all user presets that were created."""
    storage.create(_make_preset("Preset A"))
    storage.create(_make_preset("Preset B"))
    everything = storage.get_many()
    # Filter out any default presets
    user_only = [p for p in everything if p.type == "user"]
    assert len(user_only) == 2


def test_get_many_filter_by_type(storage):
    """get_many(type=...) returns only presets of that type."""
    storage.create(_make_preset("User Preset", "user"))
    # There may be defaults loaded; just verify filtering works
    filtered = storage.get_many(type="user")
    assert all(p.type == "user" for p in filtered)


def test_create_many(storage):
    """create_many persists every preset in the batch."""
    batch = [_make_preset(f"Preset {i}") for i in range(3)]
    storage.create_many(batch)
    assert len(storage.get_many(type="user")) >= 3


def test_get_many_ordered_by_name(storage):
    """Results come back case-insensitively sorted by name."""
    storage.create(_make_preset("Zebra"))
    storage.create(_make_preset("Alpha"))
    ordered = storage.get_many(type="user")
    observed = [p.name for p in ordered]
    assert observed == sorted(observed, key=str.lower)
|
||||
150
tests/app/services/test_sqlmodel_services/test_users_sqlmodel.py
Normal file
150
tests/app/services/test_sqlmodel_services/test_users_sqlmodel.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""Tests for UserServiceSqlModel."""
|
||||
|
||||
import pytest
|
||||
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.users.users_common import UserCreateRequest, UserUpdateRequest
|
||||
from invokeai.app.services.users.users_sqlmodel import UserServiceSqlModel
|
||||
|
||||
|
||||
@pytest.fixture
def user_service(db: SqliteDatabase) -> UserServiceSqlModel:
    """User service bound to the shared test database."""
    service = UserServiceSqlModel(db=db)
    return service
|
||||
|
||||
|
||||
def test_create_user(user_service: UserServiceSqlModel):
    """A new user is created as a regular, active account."""
    created = user_service.create(
        UserCreateRequest(email="test@example.com", display_name="Test User", password="TestPassword123")
    )
    assert created.email == "test@example.com"
    assert created.display_name == "Test User"
    assert created.is_admin is False
    assert created.is_active is True


def test_create_user_weak_password(user_service: UserServiceSqlModel):
    """Strict password checking rejects short passwords."""
    with pytest.raises(ValueError, match="at least 8 characters"):
        user_service.create(
            UserCreateRequest(email="test@example.com", display_name="Test", password="weak"),
            strict_password_checking=True,
        )


def test_create_user_weak_password_non_strict(user_service: UserServiceSqlModel):
    """Without strict checking, a weak password is accepted."""
    created = user_service.create(
        UserCreateRequest(email="test@example.com", display_name="Test", password="weak"),
        strict_password_checking=False,
    )
    assert created.email == "test@example.com"


def test_create_duplicate_user(user_service: UserServiceSqlModel):
    """Creating a second user with the same email is rejected."""
    request = UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123")
    user_service.create(request)
    with pytest.raises(ValueError, match="already exists"):
        user_service.create(request)


def test_get_user(user_service: UserServiceSqlModel):
    """A created user can be fetched by id."""
    created = user_service.create(
        UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123")
    )
    loaded = user_service.get(created.user_id)
    assert loaded is not None
    assert loaded.user_id == created.user_id


def test_get_nonexistent_user(user_service: UserServiceSqlModel):
    """Unknown user ids return None."""
    assert user_service.get("nonexistent-id") is None


def test_get_user_by_email(user_service: UserServiceSqlModel):
    """A created user can be fetched by email."""
    user_service.create(UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123"))
    loaded = user_service.get_by_email("test@example.com")
    assert loaded is not None
    assert loaded.email == "test@example.com"
|
||||
|
||||
|
||||
def test_update_user(user_service: UserServiceSqlModel):
|
||||
user = user_service.create(
|
||||
UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123")
|
||||
)
|
||||
updated = user_service.update(user.user_id, UserUpdateRequest(display_name="Updated", is_admin=True))
|
||||
assert updated.display_name == "Updated"
|
||||
assert updated.is_admin is True
|
||||
|
||||
|
||||
def test_delete_user(user_service: UserServiceSqlModel):
|
||||
user = user_service.create(
|
||||
UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123")
|
||||
)
|
||||
user_service.delete(user.user_id)
|
||||
assert user_service.get(user.user_id) is None
|
||||
|
||||
|
||||
def test_authenticate_valid(user_service: UserServiceSqlModel):
    """Correct credentials authenticate the user and stamp last_login_at."""
    user_service.create(UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123"))
    authed = user_service.authenticate("test@example.com", "TestPassword123")
    assert authed is not None
    assert authed.email == "test@example.com"
    # Successful authentication records a login timestamp.
    assert authed.last_login_at is not None
def test_authenticate_invalid_password(user_service: UserServiceSqlModel):
    """A wrong password yields None instead of an authenticated user."""
    user_service.create(UserCreateRequest(email="test@example.com", display_name="Test", password="TestPassword123"))
    result = user_service.authenticate("test@example.com", "WrongPassword")
    assert result is None
def test_authenticate_nonexistent(user_service: UserServiceSqlModel):
    """Authenticating with an unregistered email yields None."""
    result = user_service.authenticate("none@example.com", "TestPassword123")
    assert result is None
def test_has_admin(user_service: UserServiceSqlModel):
    """has_admin flips from False to True once an admin user exists."""
    assert user_service.has_admin() is False
    admin_request = UserCreateRequest(
        email="admin@example.com", display_name="Admin", password="AdminPassword123", is_admin=True
    )
    user_service.create(admin_request)
    assert user_service.has_admin() is True
def test_create_admin(user_service: UserServiceSqlModel):
    """create_admin marks the new user as an admin without is_admin being set."""
    request = UserCreateRequest(email="admin@example.com", display_name="Admin", password="AdminPassword123")
    admin = user_service.create_admin(request)
    assert admin.is_admin is True
def test_create_admin_when_exists(user_service: UserServiceSqlModel):
    """create_admin refuses to create a second admin once one exists."""
    first = UserCreateRequest(email="admin@example.com", display_name="Admin", password="AdminPassword123")
    user_service.create_admin(first)
    second = UserCreateRequest(email="admin2@example.com", display_name="Admin2", password="AdminPassword123")
    with pytest.raises(ValueError, match="already exists"):
        user_service.create_admin(second)
def test_list_users(user_service: UserServiceSqlModel):
    """list_users returns every user and honors the limit argument."""
    for idx in range(5):
        request = UserCreateRequest(
            email=f"test{idx}@example.com", display_name=f"User {idx}", password="TestPassword123"
        )
        user_service.create(request)
    # Migration 27 creates a 'system' user, so total = 5 + 1
    assert len(user_service.list_users()) == 6
    assert len(user_service.list_users(limit=2)) == 2
def test_get_admin_email(user_service: UserServiceSqlModel):
    """get_admin_email is None until an admin exists, then returns their email."""
    assert user_service.get_admin_email() is None
    admin_request = UserCreateRequest(
        email="admin@example.com", display_name="Admin", password="AdminPassword123", is_admin=True
    )
    user_service.create(admin_request)
    assert user_service.get_admin_email() == "admin@example.com"
def test_count_admins(user_service: UserServiceSqlModel):
    """count_admins reflects the number of admin users in the store."""
    assert user_service.count_admins() == 0
    admin_request = UserCreateRequest(
        email="admin@example.com", display_name="Admin", password="AdminPassword123", is_admin=True
    )
    user_service.create(admin_request)
    assert user_service.count_admins() == 1
Reference in New Issue
Block a user