Migrate session_queue to SQLModel (Phase 3)

Port SqliteSessionQueue to a SQLAlchemy Core / SQLModel hybrid that keeps the
existing public API and DB schema (migrations and triggers untouched). Hot
paths (enqueue bulk insert, dequeue, bulk cancel/delete, list with cursor
pagination, status aggregations) use Core to avoid ORM hydration overhead;
single-row reads stay ORM-style for clarity.

- Add SqlModelSessionQueue alongside the legacy SqliteSessionQueue
- Add the missing `workflow` column to SessionQueueTable (was added by
  migration_2 but never declared on the SQLModel)
- Wire dependencies.py to the new implementation
- Add 36 unit tests covering enqueue/dequeue, status mutations, bulk
  cancel/delete, prune-to-limit, retry, pagination and aggregations
- Avoid nested write sessions on the single StaticPool connection by reading
  the current item before opening the outer write session
This commit is contained in:
Alexander Eichhorn
2026-04-20 23:43:24 +02:00
parent 4f7343b4e4
commit 0a428ffff4
4 changed files with 1313 additions and 2 deletions

View File

@@ -42,7 +42,7 @@ from invokeai.app.services.session_processor.session_processor_default import (
DefaultSessionProcessor,
DefaultSessionRunner,
)
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
from invokeai.app.services.session_queue.session_queue_sqlmodel import SqlModelSessionQueue
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
from invokeai.app.services.style_preset_records.style_preset_records_sqlmodel import SqlModelStylePresetRecordsStorage
@@ -175,7 +175,7 @@ class ApiDependencies:
names = SimpleNameService()
performance_statistics = InvocationStatsService()
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
session_queue = SqliteSessionQueue(db=db) # Stays raw SQL (Phase 3)
session_queue = SqlModelSessionQueue(db=db)
urls = LocalUrlService()
workflow_records = SqlModelWorkflowRecordsStorage(db=db)
style_preset_records = SqlModelStylePresetRecordsStorage(db=db)

View File

@@ -0,0 +1,843 @@
"""SQLModel-backed implementation of the session queue service.

This module is the Phase 3 sibling of `session_queue_sqlite.py`. It uses
SQLAlchemy Core for the hot paths (bulk enqueue/cancel/delete, dequeue, list
with cursor pagination, aggregations) and keeps the same external behaviour as
the raw-SQL implementation, including reliance on the existing DB triggers for
`started_at`, `completed_at` and `updated_at`.
"""
import asyncio
import json
from typing import Any, Optional
from pydantic_core import to_jsonable_python
from sqlalchemy import and_, delete, func, insert, or_, select, update
from sqlalchemy.engine import Row
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
from invokeai.app.services.session_queue.session_queue_common import (
DEFAULT_QUEUE_ID,
QUEUE_ITEM_STATUS,
Batch,
BatchStatus,
CancelAllExceptCurrentResult,
CancelByBatchIDsResult,
CancelByDestinationResult,
CancelByQueueIDResult,
ClearResult,
DeleteAllExceptCurrentResult,
DeleteByDestinationResult,
EnqueueBatchResult,
IsEmptyResult,
IsFullResult,
ItemIdsResult,
PruneResult,
RetryItemsResult,
SessionQueueCountsByDestination,
SessionQueueItem,
SessionQueueItemNotFoundError,
SessionQueueStatus,
ValueToInsertTuple,
calc_session_count,
prepare_values_to_insert,
)
from invokeai.app.services.shared.graph import GraphExecutionState
from invokeai.app.services.shared.pagination import CursorPaginatedResults
from invokeai.app.services.shared.sqlite.models import SessionQueueTable, UserTable
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
# Statuses that mark a queue item as finished; finished items are never
# transitioned to another status (see `_set_queue_item_status`).
_TERMINAL_STATUSES: tuple[str, ...] = ("completed", "failed", "canceled")

# Explicit column list (instead of `SELECT *`) so that JOINed queries control
# exactly which columns appear in the result rows. `workflow` is intentionally
# excluded here and added per-query where it is needed.
_QUEUE_COLUMNS = (
    SessionQueueTable.item_id,
    SessionQueueTable.batch_id,
    SessionQueueTable.queue_id,
    SessionQueueTable.session_id,
    SessionQueueTable.field_values,
    SessionQueueTable.session,
    SessionQueueTable.status,
    SessionQueueTable.priority,
    SessionQueueTable.error_traceback,
    SessionQueueTable.created_at,
    SessionQueueTable.updated_at,
    SessionQueueTable.started_at,
    SessionQueueTable.completed_at,
    SessionQueueTable.error_type,
    SessionQueueTable.error_message,
    SessionQueueTable.origin,
    SessionQueueTable.destination,
    SessionQueueTable.retried_from_item_id,
    SessionQueueTable.user_id,
)
def _row_to_queue_item_dict(row: Row) -> dict[str, Any]:
"""Convert a Row produced by `_select_queue_item_with_user` to a plain dict
that `SessionQueueItem.queue_item_from_dict` expects."""
mapping = dict(row._mapping)
# Stringify datetime columns so the Pydantic union (`datetime | str`) accepts them
# consistently across queries that JOIN datetime columns from multiple tables.
for ts_key in ("created_at", "updated_at", "started_at", "completed_at"):
ts_value = mapping.get(ts_key)
if ts_value is not None and not isinstance(ts_value, str):
mapping[ts_key] = str(ts_value)
mapping.setdefault("user_display_name", None)
mapping.setdefault("user_email", None)
mapping.setdefault("workflow", None)
return mapping
def _select_queue_item_with_user():
    """Build a SELECT equivalent to `sq.*, u.display_name, u.email`.

    The users table is LEFT JOINed (isouter=True) so queue items whose user_id
    has no matching user row are still returned, with NULL user columns.
    """
    return (
        select(
            *_QUEUE_COLUMNS,
            SessionQueueTable.workflow,
            UserTable.display_name.label("user_display_name"),
            UserTable.email.label("user_email"),
        )
        .select_from(SessionQueueTable)
        .join(UserTable, SessionQueueTable.user_id == UserTable.user_id, isouter=True)
    )
def _value_tuple_to_dict(t: ValueToInsertTuple) -> dict[str, Any]:
"""Adapt the positional tuple from `prepare_values_to_insert` to a dict that
SQLAlchemy Core's `insert(...).values([...])` expects."""
return {
"queue_id": t[0],
"session": t[1],
"session_id": t[2],
"batch_id": t[3],
"field_values": t[4],
"priority": t[5],
"workflow": t[6],
"origin": t[7],
"destination": t[8],
"retried_from_item_id": t[9],
"user_id": t[10],
}
class SqlModelSessionQueue(SessionQueueBase):
    """Session queue service backed by SQLAlchemy Core over SQLModel tables."""

    __invoker: Invoker

    def __init__(self, db: SqliteDatabase) -> None:
        super().__init__()
        self._db = db

    def start(self, invoker: Invoker) -> None:
        """App-startup hook.

        Cancels items left in_progress by a previous run, then either clears
        the whole default queue (if configured) or prunes terminal history down
        to `max_queue_history` items.
        """
        self.__invoker = invoker
        self._set_in_progress_to_canceled()
        config = self.__invoker.services.configuration
        if config.clear_queue_on_startup:
            clear_result = self.clear(DEFAULT_QUEUE_ID)
            if clear_result.deleted > 0:
                self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
            # Clearing removes everything; pruning below would be a no-op.
            return
        if config.max_queue_history is not None:
            deleted = self._prune_terminal_to_limit(DEFAULT_QUEUE_ID, config.max_queue_history)
            if deleted > 0:
                self.__invoker.services.logger.info(
                    f"Pruned {deleted} completed/failed/canceled queue items "
                    f"(kept up to {config.max_queue_history})"
                )
# region: internal helpers

def _set_in_progress_to_canceled(self) -> None:
    """Sets all in_progress queue items to canceled. Run on app startup."""
    with self._db.get_session() as session:
        session.execute(
            update(SessionQueueTable)
            .where(SessionQueueTable.status == "in_progress")
            .values(status="canceled")
        )

def _prune_terminal_to_limit(self, queue_id: str, keep: int) -> int:
    """Prune terminal items (completed/failed/canceled) to keep at most the
    `keep` most-recent items.

    Returns:
        Number of rows deleted.
    """
    terminal_filter = and_(
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.status.in_(_TERMINAL_STATUSES),
    )
    # Subquery: ids of the items we want to keep (most recent N). Recency is
    # completed_at, falling back to updated_at then created_at; item_id breaks ties.
    keep_ids_stmt = (
        select(SessionQueueTable.item_id)
        .where(terminal_filter)
        .order_by(
            func.coalesce(
                SessionQueueTable.completed_at,
                SessionQueueTable.updated_at,
                SessionQueueTable.created_at,
            ).desc(),
            SessionQueueTable.item_id.desc(),
        )
        .limit(keep)
    )
    with self._db.get_session() as session:
        # Count first so the returned number reflects exactly what is deleted.
        count_stmt = (
            select(func.count())
            .select_from(SessionQueueTable)
            .where(terminal_filter)
            .where(~SessionQueueTable.item_id.in_(keep_ids_stmt))
        )
        count = session.execute(count_stmt).scalar_one()
        session.execute(
            delete(SessionQueueTable)
            .where(terminal_filter)
            .where(~SessionQueueTable.item_id.in_(keep_ids_stmt))
        )
    return int(count)

def _get_current_queue_size(self, queue_id: str) -> int:
    """Gets the current number of pending queue items."""
    with self._db.get_readonly_session() as session:
        count = session.execute(
            select(func.count())
            .select_from(SessionQueueTable)
            .where(
                SessionQueueTable.queue_id == queue_id,
                SessionQueueTable.status == "pending",
            )
        ).scalar_one()
    return int(count)

def _get_highest_priority(self, queue_id: str) -> int:
    """Gets the highest priority among pending items, or 0 if there are none."""
    with self._db.get_readonly_session() as session:
        priority = session.execute(
            select(func.max(SessionQueueTable.priority)).where(
                SessionQueueTable.queue_id == queue_id,
                SessionQueueTable.status == "pending",
            )
        ).scalar()
    return int(priority) if priority is not None else 0

# endregion
# region: enqueue / dequeue / read single

async def enqueue_batch(
    self, queue_id: str, batch: Batch, prepend: bool, user_id: str = "system"
) -> EnqueueBatchResult:
    """Expand `batch` into queue items and bulk-insert them.

    Session-count calculation and value preparation are CPU-bound, so they run
    in a worker thread; the insert itself is a single Core executemany.
    Emits a batch-enqueued event and returns the enqueue summary.
    """
    current_queue_size = self._get_current_queue_size(queue_id)
    max_queue_size = self.__invoker.services.configuration.max_queue_size
    max_new_queue_items = max_queue_size - current_queue_size
    priority = 0
    if prepend:
        # Prepend by outranking the current highest pending priority.
        priority = self._get_highest_priority(queue_id) + 1
    requested_count = await asyncio.to_thread(calc_session_count, batch=batch)
    values_to_insert = await asyncio.to_thread(
        prepare_values_to_insert,
        queue_id=queue_id,
        batch=batch,
        priority=priority,
        max_new_queue_items=max_new_queue_items,
        user_id=user_id,
    )
    enqueued_count = len(values_to_insert)
    with self._db.get_session() as session:
        if values_to_insert:
            session.execute(
                insert(SessionQueueTable),
                [_value_tuple_to_dict(v) for v in values_to_insert],
            )
        # NOTE(review): item_ids come back newest-first (item_id DESC) —
        # confirm callers do not rely on insertion order here.
        item_ids_rows = session.execute(
            select(SessionQueueTable.item_id)
            .where(SessionQueueTable.batch_id == batch.batch_id)
            .order_by(SessionQueueTable.item_id.desc())
        ).all()
    item_ids = [row[0] for row in item_ids_rows]
    enqueue_result = EnqueueBatchResult(
        queue_id=queue_id,
        requested=requested_count,
        enqueued=enqueued_count,
        batch=batch,
        priority=priority,
        item_ids=item_ids,
    )
    self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id)
    return enqueue_result

def dequeue(self) -> Optional[SessionQueueItem]:
    """Take the next pending item (any queue) and mark it in_progress.

    Returns None when no pending item exists. The status transition goes
    through `_set_queue_item_status`, which also emits the status-change event.
    """
    with self._db.get_readonly_session() as session:
        row = session.execute(
            _select_queue_item_with_user()
            .where(SessionQueueTable.status == "pending")
            .order_by(SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc())
            .limit(1)
        ).first()
    if row is None:
        return None
    queue_item = SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
    return self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")

def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
    """Peek at the next pending item of `queue_id` without mutating it."""
    with self._db.get_readonly_session() as session:
        row = session.execute(
            _select_queue_item_with_user()
            .where(
                SessionQueueTable.queue_id == queue_id,
                SessionQueueTable.status == "pending",
            )
            .order_by(SessionQueueTable.priority.desc(), SessionQueueTable.created_at.asc())
            .limit(1)
        ).first()
    if row is None:
        return None
    return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))

def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
    """Return the in_progress item of `queue_id`, or None if nothing runs."""
    with self._db.get_readonly_session() as session:
        row = session.execute(
            _select_queue_item_with_user()
            .where(
                SessionQueueTable.queue_id == queue_id,
                SessionQueueTable.status == "in_progress",
            )
            .limit(1)
        ).first()
    if row is None:
        return None
    return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))

def get_queue_item(self, item_id: int) -> SessionQueueItem:
    """Fetch a single queue item by id.

    Raises:
        SessionQueueItemNotFoundError: if no item with `item_id` exists.
    """
    with self._db.get_readonly_session() as session:
        row = session.execute(
            _select_queue_item_with_user().where(SessionQueueTable.item_id == item_id)
        ).first()
    if row is None:
        raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
    return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))

# endregion
# region: status mutation

def _set_queue_item_status(
    self,
    item_id: int,
    status: QUEUE_ITEM_STATUS,
    error_type: Optional[str] = None,
    error_message: Optional[str] = None,
    error_traceback: Optional[str] = None,
) -> SessionQueueItem:
    """Transition an item to `status`, returning the refreshed item.

    Items already in a terminal status (completed/failed/canceled) are left
    untouched and no status-change event is emitted for them.

    Raises:
        SessionQueueItemNotFoundError: if no item with `item_id` exists.
    """
    with self._db.get_session() as session:
        current_status = session.execute(
            select(SessionQueueTable.status).where(SessionQueueTable.item_id == item_id)
        ).scalar()
        if current_status is None:
            raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
        # Only update if not already finished (completed, failed or canceled).
        if current_status not in _TERMINAL_STATUSES:
            session.execute(
                update(SessionQueueTable)
                .where(SessionQueueTable.item_id == item_id)
                .values(
                    status=status,
                    error_type=error_type,
                    error_message=error_message,
                    error_traceback=error_traceback,
                )
            )
    # Re-read after the write session closes; per the module docstring, DB
    # triggers maintain the timestamp columns on update.
    queue_item = self.get_queue_item(item_id)
    # If we did not update, do not emit a status change event.
    if current_status not in _TERMINAL_STATUSES:
        batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
        queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
        self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
    return queue_item

def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
    """Mark a single item as canceled."""
    return self._set_queue_item_status(item_id=item_id, status="canceled")

def complete_queue_item(self, item_id: int) -> SessionQueueItem:
    """Mark a single item as completed."""
    return self._set_queue_item_status(item_id=item_id, status="completed")

def fail_queue_item(
    self,
    item_id: int,
    error_type: str,
    error_message: str,
    error_traceback: str,
) -> SessionQueueItem:
    """Mark a single item as failed, recording the error details."""
    return self._set_queue_item_status(
        item_id=item_id,
        status="failed",
        error_type=error_type,
        error_message=error_message,
        error_traceback=error_traceback,
    )

def delete_queue_item(self, item_id: int) -> None:
    """Cancel (best-effort) then delete an item. A missing item is a no-op."""
    try:
        self.cancel_queue_item(item_id)
    except SessionQueueItemNotFoundError:
        pass
    with self._db.get_session() as session:
        session.execute(delete(SessionQueueTable).where(SessionQueueTable.item_id == item_id))

def set_queue_item_session(self, item_id: int, session_state: GraphExecutionState) -> SessionQueueItem:
    """Persist the item's execution-state JSON and return the refreshed item."""
    # Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause
    # validation errors when the graph is loaded. Graph execution occurs purely in memory - the
    # session saved here is not referenced during execution.
    session_json = session_state.model_dump_json(warnings=False, exclude_none=True)
    with self._db.get_session() as session:
        session.execute(
            update(SessionQueueTable)
            .where(SessionQueueTable.item_id == item_id)
            .values(session=session_json)
        )
    return self.get_queue_item(item_id)

# endregion
# region: simple status checks

def is_empty(self, queue_id: str) -> IsEmptyResult:
    """True when the queue has no items at all, regardless of status."""
    with self._db.get_readonly_session() as session:
        count = session.execute(
            select(func.count())
            .select_from(SessionQueueTable)
            .where(SessionQueueTable.queue_id == queue_id)
        ).scalar_one()
    return IsEmptyResult(is_empty=int(count) == 0)

def is_full(self, queue_id: str) -> IsFullResult:
    """True when the total item count has reached the configured max size."""
    with self._db.get_readonly_session() as session:
        count = session.execute(
            select(func.count())
            .select_from(SessionQueueTable)
            .where(SessionQueueTable.queue_id == queue_id)
        ).scalar_one()
    max_queue_size = self.__invoker.services.configuration.max_queue_size
    return IsFullResult(is_full=int(count) >= max_queue_size)

# endregion
# region: bulk delete

def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
    """Delete every item in the queue (optionally only one user's items) and
    emit a queue-cleared event. Returns the number of deleted rows."""
    where = [SessionQueueTable.queue_id == queue_id]
    if user_id is not None:
        where.append(SessionQueueTable.user_id == user_id)
    with self._db.get_session() as session:
        # Count before deleting so the result reflects what was removed.
        count = session.execute(
            select(func.count()).select_from(SessionQueueTable).where(*where)
        ).scalar_one()
        session.execute(delete(SessionQueueTable).where(*where))
    self.__invoker.services.events.emit_queue_cleared(queue_id)
    return ClearResult(deleted=int(count))

def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult:
    """Delete all terminal (completed/failed/canceled) items."""
    where = [
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.status.in_(_TERMINAL_STATUSES),
    ]
    if user_id is not None:
        where.append(SessionQueueTable.user_id == user_id)
    with self._db.get_session() as session:
        count = session.execute(
            select(func.count()).select_from(SessionQueueTable).where(*where)
        ).scalar_one()
        session.execute(delete(SessionQueueTable).where(*where))
    return PruneResult(deleted=int(count))

def delete_by_destination(
    self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> DeleteByDestinationResult:
    """Delete all items with the given destination, canceling the current
    in-progress item first if it matches."""
    # Handle current in-progress item BEFORE opening a write session of our own,
    # to avoid nested writes on the single StaticPool connection.
    current_queue_item = self.get_current(queue_id)
    if current_queue_item is not None and current_queue_item.destination == destination:
        if user_id is None or current_queue_item.user_id == user_id:
            self.cancel_queue_item(current_queue_item.item_id)
    where = [
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.destination == destination,
    ]
    if user_id is not None:
        where.append(SessionQueueTable.user_id == user_id)
    with self._db.get_session() as session:
        count = session.execute(
            select(func.count()).select_from(SessionQueueTable).where(*where)
        ).scalar_one()
        session.execute(delete(SessionQueueTable).where(*where))
    return DeleteByDestinationResult(deleted=int(count))

def delete_all_except_current(
    self, queue_id: str, user_id: Optional[str] = None
) -> DeleteAllExceptCurrentResult:
    """Delete all pending items, leaving the in-progress and terminal ones."""
    where = [
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.status == "pending",
    ]
    if user_id is not None:
        where.append(SessionQueueTable.user_id == user_id)
    with self._db.get_session() as session:
        count = session.execute(
            select(func.count()).select_from(SessionQueueTable).where(*where)
        ).scalar_one()
        session.execute(delete(SessionQueueTable).where(*where))
    return DeleteAllExceptCurrentResult(deleted=int(count))

# endregion
# region: bulk cancel

def _cancel_skip_in_progress_filter(
    self, queue_id: str, user_id: Optional[str], extra: list
) -> list:
    """WHERE clauses selecting bulk-cancelable items: not terminal and not
    in_progress (callers handle the current in-progress item separately,
    via `_set_queue_item_status`, so its event is emitted)."""
    where = [
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.status.notin_(("canceled", "completed", "failed", "in_progress")),
    ]
    if user_id is not None:
        where.append(SessionQueueTable.user_id == user_id)
    where.extend(extra)
    return where

def cancel_by_batch_ids(
    self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None
) -> CancelByBatchIDsResult:
    """Cancel all cancelable items belonging to the given batches."""
    # Read the current item before opening our write session (single
    # StaticPool connection - avoid nested write sessions).
    current_queue_item = self.get_current(queue_id)
    where = self._cancel_skip_in_progress_filter(
        queue_id, user_id, [SessionQueueTable.batch_id.in_(batch_ids)]
    )
    with self._db.get_session() as session:
        count = session.execute(
            select(func.count()).select_from(SessionQueueTable).where(*where)
        ).scalar_one()
        session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
    # Handle current item separately - check ownership if user_id is provided
    if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
        if user_id is None or current_queue_item.user_id == user_id:
            self._set_queue_item_status(current_queue_item.item_id, "canceled")
    return CancelByBatchIDsResult(canceled=int(count))

def cancel_by_destination(
    self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> CancelByDestinationResult:
    """Cancel all cancelable items with the given destination."""
    current_queue_item = self.get_current(queue_id)
    where = self._cancel_skip_in_progress_filter(
        queue_id, user_id, [SessionQueueTable.destination == destination]
    )
    with self._db.get_session() as session:
        count = session.execute(
            select(func.count()).select_from(SessionQueueTable).where(*where)
        ).scalar_one()
        session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
    if current_queue_item is not None and current_queue_item.destination == destination:
        if user_id is None or current_queue_item.user_id == user_id:
            self._set_queue_item_status(current_queue_item.item_id, "canceled")
    return CancelByDestinationResult(canceled=int(count))

def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
    """Cancel every cancelable item in the queue, including the current one."""
    current_queue_item = self.get_current(queue_id)
    where = [
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.status.notin_(("canceled", "completed", "failed", "in_progress")),
    ]
    with self._db.get_session() as session:
        count = session.execute(
            select(func.count()).select_from(SessionQueueTable).where(*where)
        ).scalar_one()
        session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
    if current_queue_item is not None and current_queue_item.queue_id == queue_id:
        self._set_queue_item_status(current_queue_item.item_id, "canceled")
    return CancelByQueueIDResult(canceled=int(count))

def cancel_all_except_current(
    self, queue_id: str, user_id: Optional[str] = None
) -> CancelAllExceptCurrentResult:
    """Cancel all pending items, leaving the in-progress item running."""
    where = [
        SessionQueueTable.queue_id == queue_id,
        SessionQueueTable.status == "pending",
    ]
    if user_id is not None:
        where.append(SessionQueueTable.user_id == user_id)
    with self._db.get_session() as session:
        count = session.execute(
            select(func.count()).select_from(SessionQueueTable).where(*where)
        ).scalar_one()
        session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
    return CancelAllExceptCurrentResult(canceled=int(count))

# endregion
# region: list / pagination

def list_queue_items(
    self,
    queue_id: str,
    limit: int,
    priority: int,
    cursor: Optional[int] = None,
    status: Optional[QUEUE_ITEM_STATUS] = None,
    destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
    """Cursor-paginated listing ordered by priority DESC, item_id ASC.

    `cursor` is the item_id of the last item on the previous page and
    `priority` that item's priority; together they locate the page boundary.
    """
    # NOTE(review): the cursor OR-group below is explicitly parenthesised, so
    # queue_id/status/destination constrain BOTH of its branches. A raw
    # unparenthesised `... AND x OR y` would not apply those filters to the OR
    # branch - confirm this (stricter) filtering matches the legacy
    # implementation's intent.
    item_id = cursor
    stmt = select(*_QUEUE_COLUMNS, SessionQueueTable.workflow).where(
        SessionQueueTable.queue_id == queue_id
    )
    if status is not None:
        stmt = stmt.where(SessionQueueTable.status == status)
    if destination is not None:
        stmt = stmt.where(SessionQueueTable.destination == destination)
    if item_id is not None:
        # Rows after the cursor: strictly lower priority, or same priority
        # with a higher item_id.
        stmt = stmt.where(
            or_(
                SessionQueueTable.priority < priority,
                and_(
                    SessionQueueTable.priority == priority,
                    SessionQueueTable.item_id > item_id,
                ),
            )
        )
    # Fetch one extra row to detect whether another page exists.
    stmt = stmt.order_by(
        SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc()
    ).limit(limit + 1)
    with self._db.get_readonly_session() as session:
        rows = session.execute(stmt).all()
    items = [SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(r)) for r in rows]
    has_more = False
    if len(items) > limit:
        items.pop()
        has_more = True
    return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)

def list_all_queue_items(
    self,
    queue_id: str,
    destination: Optional[str] = None,
) -> list[SessionQueueItem]:
    """Unpaginated listing (with user columns), ordered like the queue."""
    stmt = _select_queue_item_with_user().where(SessionQueueTable.queue_id == queue_id)
    if destination is not None:
        stmt = stmt.where(SessionQueueTable.destination == destination)
    stmt = stmt.order_by(
        SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc()
    )
    with self._db.get_readonly_session() as session:
        rows = session.execute(stmt).all()
    return [SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(r)) for r in rows]

def get_queue_item_ids(
    self,
    queue_id: str,
    order_dir: SQLiteDirection = SQLiteDirection.Descending,
    user_id: Optional[str] = None,
) -> ItemIdsResult:
    """All item ids of the queue ordered by created_at in `order_dir`."""
    stmt = select(SessionQueueTable.item_id).where(SessionQueueTable.queue_id == queue_id)
    if user_id is not None:
        stmt = stmt.where(SessionQueueTable.user_id == user_id)
    if order_dir == SQLiteDirection.Descending:
        stmt = stmt.order_by(SessionQueueTable.created_at.desc())
    else:
        stmt = stmt.order_by(SessionQueueTable.created_at.asc())
    with self._db.get_readonly_session() as session:
        rows = session.execute(stmt).all()
    item_ids = [row[0] for row in rows]
    return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids))

# endregion
# region: aggregations

def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
    """Per-status counts for the queue plus the current item's identifiers."""
    stmt = (
        select(SessionQueueTable.status, func.count())
        .where(SessionQueueTable.queue_id == queue_id)
        .group_by(SessionQueueTable.status)
    )
    if user_id is not None:
        stmt = stmt.where(SessionQueueTable.user_id == user_id)
    with self._db.get_readonly_session() as session:
        rows = session.execute(stmt).all()
    current_item = self.get_current(queue_id=queue_id)
    total = sum(int(row[1] or 0) for row in rows)
    counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}
    # For non-admin users, hide current item details if they don't own it
    show_current_item = current_item is not None and (
        user_id is None or current_item.user_id == user_id
    )
    return SessionQueueStatus(
        queue_id=queue_id,
        item_id=current_item.item_id if show_current_item else None,
        session_id=current_item.session_id if show_current_item else None,
        batch_id=current_item.batch_id if show_current_item else None,
        pending=counts.get("pending", 0),
        in_progress=counts.get("in_progress", 0),
        completed=counts.get("completed", 0),
        failed=counts.get("failed", 0),
        canceled=counts.get("canceled", 0),
        total=total,
    )

def get_batch_status(
    self, queue_id: str, batch_id: str, user_id: Optional[str] = None
) -> BatchStatus:
    """Per-status counts for a single batch."""
    # NOTE(review): origin/destination are bare (non-aggregated) columns in a
    # GROUP BY query - SQLite permits this; the code assumes all rows of a
    # batch share the same origin/destination. Confirm that invariant holds.
    stmt = (
        select(
            SessionQueueTable.status,
            func.count(),
            SessionQueueTable.origin,
            SessionQueueTable.destination,
        )
        .where(
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.batch_id == batch_id,
        )
        .group_by(SessionQueueTable.status)
    )
    if user_id is not None:
        stmt = stmt.where(SessionQueueTable.user_id == user_id)
    with self._db.get_readonly_session() as session:
        rows = session.execute(stmt).all()
    total = sum(int(row[1] or 0) for row in rows)
    counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}
    origin = rows[0][2] if rows else None
    destination = rows[0][3] if rows else None
    return BatchStatus(
        batch_id=batch_id,
        origin=origin,
        destination=destination,
        queue_id=queue_id,
        pending=counts.get("pending", 0),
        in_progress=counts.get("in_progress", 0),
        completed=counts.get("completed", 0),
        failed=counts.get("failed", 0),
        canceled=counts.get("canceled", 0),
        total=total,
    )

def get_counts_by_destination(
    self, queue_id: str, destination: str, user_id: Optional[str] = None
) -> SessionQueueCountsByDestination:
    """Per-status counts for all items sharing a destination."""
    stmt = (
        select(SessionQueueTable.status, func.count())
        .where(
            SessionQueueTable.queue_id == queue_id,
            SessionQueueTable.destination == destination,
        )
        .group_by(SessionQueueTable.status)
    )
    if user_id is not None:
        stmt = stmt.where(SessionQueueTable.user_id == user_id)
    with self._db.get_readonly_session() as session:
        rows = session.execute(stmt).all()
    total = sum(int(row[1] or 0) for row in rows)
    counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}
    return SessionQueueCountsByDestination(
        queue_id=queue_id,
        destination=destination,
        pending=counts.get("pending", 0),
        in_progress=counts.get("in_progress", 0),
        completed=counts.get("completed", 0),
        failed=counts.get("failed", 0),
        canceled=counts.get("canceled", 0),
        total=total,
    )

# endregion
# region: retry

def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
    """Clone failed/canceled items as fresh pending items.

    Items in any other status are skipped. Emits a retried event and returns
    the ids of the items that were actually retried.
    """
    values_to_insert: list[ValueToInsertTuple] = []
    retried_item_ids: list[int] = []
    for item_id in item_ids:
        queue_item = self.get_queue_item(item_id)
        # Only failed or canceled items may be retried.
        if queue_item.status not in ("failed", "canceled"):
            continue
        retried_item_ids.append(item_id)
        field_values_json = (
            json.dumps(queue_item.field_values, default=to_jsonable_python)
            if queue_item.field_values
            else None
        )
        workflow_json = (
            json.dumps(queue_item.workflow, default=to_jsonable_python)
            if queue_item.workflow
            else None
        )
        # A fresh GraphExecutionState gives the clone a new session id and
        # discards the old execution progress.
        cloned_session = GraphExecutionState(graph=queue_item.session.graph)
        cloned_session_json = cloned_session.model_dump_json(warnings=False, exclude_none=True)
        # Chain retries back to the ORIGINAL item, not the previous retry.
        retried_from_item_id = (
            queue_item.retried_from_item_id
            if queue_item.retried_from_item_id is not None
            else queue_item.item_id
        )
        values_to_insert.append(
            (
                queue_item.queue_id,
                cloned_session_json,
                cloned_session.id,
                queue_item.batch_id,
                field_values_json,
                queue_item.priority,
                workflow_json,
                queue_item.origin,
                queue_item.destination,
                retried_from_item_id,
                queue_item.user_id,
            )
        )
    # TODO(psyche): Handle max queue size?
    if values_to_insert:
        with self._db.get_session() as session:
            session.execute(
                insert(SessionQueueTable),
                [_value_tuple_to_dict(v) for v in values_to_insert],
            )
    retry_result = RetryItemsResult(queue_id=queue_id, retried_item_ids=retried_item_ids)
    self.__invoker.services.events.emit_queue_items_retried(retry_result)
    return retry_result

# endregion

View File

@@ -141,6 +141,7 @@ class SessionQueueTable(SQLModel, table=True):
destination: Optional[str] = Field(default=None)
retried_from_item_id: Optional[int] = Field(default=None)
user_id: str = Field(default="system")
workflow: Optional[str] = Field(default=None) # JSON blob
# --- models ---

View File

@@ -0,0 +1,467 @@
"""Tests for the SQLModel-backed session queue implementation."""
import asyncio
import uuid
from typing import Optional
import pytest
from sqlalchemy import insert
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.session_queue.session_queue_common import (
Batch,
SessionQueueItemNotFoundError,
)
from invokeai.app.services.session_queue.session_queue_sqlmodel import SqlModelSessionQueue
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
from invokeai.app.services.shared.sqlite.models import SessionQueueTable
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
from tests.test_nodes import PromptTestInvocation
# ---- fixtures ----


@pytest.fixture
def session_queue(mock_invoker: Invoker) -> SqlModelSessionQueue:
    """Create a SqlModelSessionQueue backed by the mock invoker's in-memory database."""
    db = mock_invoker.services.board_records._db
    queue = SqlModelSessionQueue(db=db)
    queue.start(mock_invoker)
    return queue


@pytest.fixture
def batch_graph() -> Graph:
    """A minimal single-node graph usable as a Batch payload."""
    g = Graph()
    g.add_node(PromptTestInvocation(id="1", prompt="Chevy"))
    return g


# ---- helpers ----


def _make_session_json() -> tuple[str, str]:
    """Build a valid GraphExecutionState JSON blob and return (session_id, json)."""
    g = Graph()
    g.add_node(PromptTestInvocation(id="1", prompt="Chevy"))
    state = GraphExecutionState(graph=g)
    return state.id, state.model_dump_json(warnings=False, exclude_none=True)


def _insert_raw(
    queue: SqlModelSessionQueue,
    *,
    queue_id: str = "default",
    user_id: str = "system",
    status: str = "pending",
    priority: int = 0,
    batch_id: Optional[str] = None,
    destination: Optional[str] = None,
) -> int:
    """Insert a minimal queue item via Core and return its item_id."""
    session_id, session_json = _make_session_json()
    batch_id = batch_id or str(uuid.uuid4())
    with queue._db.get_session() as session:
        result = session.execute(
            insert(SessionQueueTable).values(
                queue_id=queue_id,
                session=session_json,
                session_id=session_id,
                batch_id=batch_id,
                field_values=None,
                priority=priority,
                workflow=None,
                origin=None,
                destination=destination,
                retried_from_item_id=None,
                user_id=user_id,
                status=status,
            )
        )
    return int(result.inserted_primary_key[0])
# ---- start() / _set_in_progress_to_canceled ----
def test_start_cancels_in_progress(mock_invoker: Invoker) -> None:
db = mock_invoker.services.board_records._db
queue = SqlModelSessionQueue(db=db)
in_progress_id = _insert_raw(queue, status="in_progress")
queue.start(mock_invoker)
item = queue.get_queue_item(in_progress_id)
assert item.status == "canceled"
# ---- simple read methods ----
def test_is_empty_and_is_full(session_queue: SqlModelSessionQueue) -> None:
assert session_queue.is_empty("default").is_empty is True
_insert_raw(session_queue)
assert session_queue.is_empty("default").is_empty is False
# default max_queue_size is high; queue with 1 item is not full
assert session_queue.is_full("default").is_full is False
def test_get_queue_item_not_found(session_queue: SqlModelSessionQueue) -> None:
with pytest.raises(SessionQueueItemNotFoundError):
session_queue.get_queue_item(99999)
def test_get_queue_item(session_queue: SqlModelSessionQueue) -> None:
item_id = _insert_raw(session_queue, user_id="alice")
item = session_queue.get_queue_item(item_id)
assert item.item_id == item_id
assert item.user_id == "alice"
assert item.status == "pending"
def test_get_current_and_get_next(session_queue: SqlModelSessionQueue) -> None:
    """get_current returns the in-progress item; get_next the pending one."""
    pending_id = _insert_raw(session_queue, priority=1)
    running_id = _insert_raw(session_queue, status="in_progress")
    running = session_queue.get_current("default")
    assert running is not None
    assert running.item_id == running_id
    upcoming = session_queue.get_next("default")
    assert upcoming is not None
    assert upcoming.item_id == pending_id
def test_get_current_queue_size(session_queue: SqlModelSessionQueue) -> None:
    """The current queue size excludes the completed item."""
    for status in ("pending", "pending", "completed"):
        _insert_raw(session_queue, status=status)
    assert session_queue._get_current_queue_size("default") == 2
def test_get_highest_priority(session_queue: SqlModelSessionQueue) -> None:
    """Highest priority defaults to 0 and ignores completed items."""
    assert session_queue._get_highest_priority("default") == 0
    _insert_raw(session_queue, priority=3)
    _insert_raw(session_queue, priority=7)
    # a completed item with a higher priority must not be considered
    _insert_raw(session_queue, priority=10, status="completed")
    assert session_queue._get_highest_priority("default") == 7
# ---- enqueue / dequeue ----
def test_enqueue_batch_and_dequeue(
    session_queue: SqlModelSessionQueue, batch_graph: Graph
) -> None:
    """enqueue_batch creates one item per run; dequeue claims the first one."""
    enq = asyncio.run(
        session_queue.enqueue_batch("default", Batch(graph=batch_graph, runs=2), prepend=False)
    )
    assert enq.enqueued == 2
    assert enq.requested == 2
    assert len(enq.item_ids) == 2
    # Dequeueing promotes the first pending item to in_progress...
    claimed = session_queue.dequeue()
    assert claimed is not None
    assert claimed.status == "in_progress"
    # ...and it becomes the single current item.
    active = session_queue.get_current("default")
    assert active is not None
    assert active.item_id == claimed.item_id
def test_enqueue_batch_prepend_increases_priority(
    session_queue: SqlModelSessionQueue, batch_graph: Graph
) -> None:
    """Prepending a batch bumps its priority above the existing items."""
    asyncio.run(session_queue.enqueue_batch("default", Batch(graph=batch_graph), prepend=False))
    prepended = asyncio.run(
        session_queue.enqueue_batch("default", Batch(graph=batch_graph), prepend=True)
    )
    assert prepended.priority == 1
def test_dequeue_empty_returns_none(session_queue: SqlModelSessionQueue) -> None:
    """Dequeueing from an empty queue yields None rather than raising."""
    assert session_queue.dequeue() is None
# ---- status mutations ----
def test_complete_fail_cancel_queue_item(session_queue: SqlModelSessionQueue) -> None:
    """Terminal-status transitions stick; later mutations are no-ops."""
    completed_id = _insert_raw(session_queue)
    assert session_queue.complete_queue_item(completed_id).status == "completed"
    # cancel on an already-terminal item returns it unchanged
    assert session_queue.cancel_queue_item(completed_id).status == "completed"
    failed_id = _insert_raw(session_queue)
    failed_item = session_queue.fail_queue_item(failed_id, "ErrType", "ErrMsg", "trace")
    assert failed_item.status == "failed"
    assert failed_item.error_type == "ErrType"
    assert failed_item.error_message == "ErrMsg"
    assert failed_item.error_traceback == "trace"
    canceled_id = _insert_raw(session_queue)
    assert session_queue.cancel_queue_item(canceled_id).status == "canceled"
def test_set_queue_item_status_unknown_id_raises(
    session_queue: SqlModelSessionQueue,
) -> None:
    """Mutating the status of an unknown item_id raises not-found."""
    unknown_id = 99999
    with pytest.raises(SessionQueueItemNotFoundError):
        session_queue._set_queue_item_status(unknown_id, "completed")
def test_delete_queue_item(session_queue: SqlModelSessionQueue) -> None:
    """A deleted item can no longer be fetched."""
    doomed_id = _insert_raw(session_queue)
    session_queue.delete_queue_item(doomed_id)
    with pytest.raises(SessionQueueItemNotFoundError):
        session_queue.get_queue_item(doomed_id)
def test_set_queue_item_session(
    session_queue: SqlModelSessionQueue, batch_graph: Graph
) -> None:
    """Replacing an item's session persists the new session id."""
    target_id = _insert_raw(session_queue)
    replacement = GraphExecutionState(graph=batch_graph)
    session_queue.set_queue_item_session(target_id, replacement)
    reloaded = session_queue.get_queue_item(target_id)
    assert reloaded.session.id == replacement.id
# ---- bulk delete ----
def test_clear_with_user_id_only_deletes_own_items(
    session_queue: SqlModelSessionQueue,
) -> None:
    """clear() scoped to a user leaves other users' items alone."""
    for owner in ("user_a", "user_a", "user_b"):
        _insert_raw(session_queue, user_id=owner)
    outcome = session_queue.clear("default", user_id="user_a")
    assert outcome.deleted == 2
def test_clear_without_user_id_deletes_all(session_queue: SqlModelSessionQueue) -> None:
    """clear() without a user filter removes every item in the queue."""
    _insert_raw(session_queue, user_id="user_a")
    _insert_raw(session_queue, user_id="user_b")
    assert session_queue.clear("default").deleted == 2
def test_prune_only_deletes_terminal(session_queue: SqlModelSessionQueue) -> None:
    """prune() removes completed/failed/canceled items but keeps active ones."""
    for status in ("pending", "completed", "failed", "canceled", "in_progress"):
        _insert_raw(session_queue, status=status)
    pruned = session_queue.prune("default")
    assert pruned.deleted == 3
    # the non-terminal items survive the prune
    remaining = session_queue.get_queue_status("default")
    assert remaining.pending == 1
    assert remaining.in_progress == 1
def test_prune_with_user_id(session_queue: SqlModelSessionQueue) -> None:
    """prune() scoped to a user only drops that user's terminal items."""
    _insert_raw(session_queue, status="completed", user_id="user_a")
    _insert_raw(session_queue, status="failed", user_id="user_b")
    assert session_queue.prune("default", user_id="user_a").deleted == 1
def test_delete_by_destination(session_queue: SqlModelSessionQueue) -> None:
    """Deletion by destination only touches matching items."""
    for dest in ("canvas", "canvas", "generate"):
        _insert_raw(session_queue, destination=dest)
    outcome = session_queue.delete_by_destination("default", destination="canvas")
    assert outcome.deleted == 2
def test_delete_all_except_current(session_queue: SqlModelSessionQueue) -> None:
    """delete_all_except_current removes pending items only."""
    for status in ("pending", "pending", "in_progress", "completed"):
        _insert_raw(session_queue, status=status)
    outcome = session_queue.delete_all_except_current("default")
    assert outcome.deleted == 2
    # in-progress and completed items are untouched
    after = session_queue.get_queue_status("default")
    assert after.pending == 0
    assert after.in_progress == 1
    assert after.completed == 1
# ---- bulk cancel ----
def test_cancel_by_batch_ids(session_queue: SqlModelSessionQueue) -> None:
    """Canceling by batch id leaves other batches untouched."""
    target_batch = str(uuid.uuid4())
    _insert_raw(session_queue, batch_id=target_batch)
    _insert_raw(session_queue, batch_id=target_batch)
    _insert_raw(session_queue, batch_id=str(uuid.uuid4()))  # unrelated batch
    outcome = session_queue.cancel_by_batch_ids("default", [target_batch])
    assert outcome.canceled == 2
def test_cancel_by_destination(session_queue: SqlModelSessionQueue) -> None:
    """Cancel-by-destination skips terminal items and other destinations."""
    _insert_raw(session_queue, destination="canvas")
    _insert_raw(session_queue, destination="canvas", status="completed")  # terminal: skipped
    _insert_raw(session_queue, destination="generate")  # other destination
    assert session_queue.cancel_by_destination("default", "canvas").canceled == 1
def test_cancel_by_queue_id(session_queue: SqlModelSessionQueue) -> None:
    """Cancel-by-queue-id only affects the named queue."""
    for qid in ("default", "default", "other"):
        _insert_raw(session_queue, queue_id=qid)
    assert session_queue.cancel_by_queue_id("default").canceled == 2
def test_cancel_all_except_current(session_queue: SqlModelSessionQueue) -> None:
    """cancel_all_except_current cancels pending items but not the active one."""
    for status in ("pending", "pending", "in_progress"):
        _insert_raw(session_queue, status=status)
    assert session_queue.cancel_all_except_current("default").canceled == 2
# ---- prune-to-limit ----
def test_prune_terminal_to_limit_keeps_n_most_recent(
    session_queue: SqlModelSessionQueue,
) -> None:
    """_prune_terminal_to_limit trims terminal items down to `keep` entries."""
    for _ in range(5):
        _insert_raw(session_queue, status="completed")
    removed = session_queue._prune_terminal_to_limit("default", keep=2)
    assert removed == 3
    assert session_queue.get_queue_status("default").completed == 2
# ---- list / pagination ----
def test_list_queue_items_pagination(session_queue: SqlModelSessionQueue) -> None:
    """Cursor pagination returns disjoint pages and signals has_more."""
    all_ids = [_insert_raw(session_queue) for _ in range(5)]
    first_page = session_queue.list_queue_items("default", limit=2, priority=0)
    assert len(first_page.items) == 2
    assert first_page.has_more is True
    cursor = first_page.items[-1].item_id
    second_page = session_queue.list_queue_items("default", limit=2, priority=0, cursor=cursor)
    assert len(second_page.items) == 2
    # the two pages must not share any items
    seen = {item.item_id for item in first_page.items}
    seen |= {item.item_id for item in second_page.items}
    assert seen.issubset(set(all_ids))
    assert len(seen) == 4
def test_list_queue_items_filters_status_and_destination(
    session_queue: SqlModelSessionQueue,
) -> None:
    """Status and destination filters combine conjunctively."""
    _insert_raw(session_queue, destination="canvas", status="completed")
    _insert_raw(session_queue, destination="canvas", status="pending")
    _insert_raw(session_queue, destination="generate", status="completed")
    filtered = session_queue.list_queue_items(
        "default", limit=10, priority=0, status="completed", destination="canvas"
    )
    assert len(filtered.items) == 1
def test_list_all_queue_items(session_queue: SqlModelSessionQueue) -> None:
    """list_all_queue_items honors the destination filter."""
    for dest in ("canvas", "canvas", "generate"):
        _insert_raw(session_queue, destination=dest)
    canvas_items = session_queue.list_all_queue_items("default", destination="canvas")
    assert len(canvas_items) == 2
def test_get_queue_item_ids_ordering(session_queue: SqlModelSessionQueue) -> None:
    """Both sort directions return the full id set with a correct total.

    Items inserted within the same millisecond may tie on created_at, so only
    set-equality and total_count are asserted; ordering itself is exercised by
    the production SQL query construction.
    """
    inserted = [_insert_raw(session_queue) for _ in range(3)]
    descending = session_queue.get_queue_item_ids("default", order_dir=SQLiteDirection.Descending)
    ascending = session_queue.get_queue_item_ids("default", order_dir=SQLiteDirection.Ascending)
    assert descending.total_count == 3
    assert ascending.total_count == 3
    assert set(descending.item_ids) == set(inserted)
    assert set(ascending.item_ids) == set(inserted)
def test_get_queue_item_ids_filters_user_id(session_queue: SqlModelSessionQueue) -> None:
    """The user_id filter restricts the id listing to that user's items."""
    _insert_raw(session_queue, user_id="alice")
    _insert_raw(session_queue, user_id="bob")
    alice_only = session_queue.get_queue_item_ids("default", user_id="alice")
    assert alice_only.total_count == 1
# ---- aggregations ----
def test_get_queue_status_counts(session_queue: SqlModelSessionQueue) -> None:
    """get_queue_status tallies one count per status plus a grand total."""
    for status in ("pending", "completed", "failed", "canceled"):
        _insert_raw(session_queue, status=status)
    tally = session_queue.get_queue_status("default")
    assert tally.pending == 1
    assert tally.completed == 1
    assert tally.failed == 1
    assert tally.canceled == 1
    assert tally.total == 4
def test_get_queue_status_user_id_hides_other_user_current(
    session_queue: SqlModelSessionQueue,
) -> None:
    """A user's status view must not expose another user's current item."""
    _insert_raw(session_queue, user_id="alice", status="in_progress")
    bobs_view = session_queue.get_queue_status("default", user_id="bob")
    # alice's in-progress item exists, but bob must not see it as current
    assert bobs_view.item_id is None
def test_get_batch_status(session_queue: SqlModelSessionQueue) -> None:
    """Batch status counts only the items belonging to that batch."""
    tracked_batch = str(uuid.uuid4())
    _insert_raw(session_queue, batch_id=tracked_batch, status="pending")
    _insert_raw(session_queue, batch_id=tracked_batch, status="completed")
    _insert_raw(session_queue, batch_id=str(uuid.uuid4()), status="completed")
    batch_status = session_queue.get_batch_status("default", batch_id=tracked_batch)
    assert batch_status.pending == 1
    assert batch_status.completed == 1
    assert batch_status.total == 2
def test_get_counts_by_destination(session_queue: SqlModelSessionQueue) -> None:
    """Per-destination counts exclude items targeting other destinations."""
    _insert_raw(session_queue, destination="canvas", status="pending")
    _insert_raw(session_queue, destination="canvas", status="completed")
    _insert_raw(session_queue, destination="generate", status="pending")
    counts = session_queue.get_counts_by_destination("default", destination="canvas")
    assert counts.pending == 1
    assert counts.completed == 1
    assert counts.total == 2
# ---- retry ----
def test_retry_items_by_id_skips_non_terminal(
    session_queue: SqlModelSessionQueue,
) -> None:
    """Retry must ignore items that are not in a terminal state.

    Note: the unused `batch_graph` fixture parameter was removed — pytest
    resolves fixtures by name, so dropping it avoids building a graph this
    test never uses.
    """
    pending_id = _insert_raw(session_queue, status="pending")
    result = session_queue.retry_items_by_id("default", [pending_id])
    assert result.retried_item_ids == []
def test_retry_items_by_id_clones_failed(
    session_queue: SqlModelSessionQueue, batch_graph: Graph
) -> None:
    """Retrying a failed item re-queues a clone as a new pending item."""
    # enqueue_batch produces a valid `session` payload, which retry requires
    enq = asyncio.run(
        session_queue.enqueue_batch("default", Batch(graph=batch_graph, runs=1), prepend=False)
    )
    failed_id = enq.item_ids[0]
    session_queue.fail_queue_item(failed_id, "ErrType", "ErrMsg", "trace")
    retried = session_queue.retry_items_by_id("default", [failed_id])
    assert retried.retried_item_ids == [failed_id]
    # the original stays failed; exactly one fresh pending clone exists
    after = session_queue.get_queue_status("default")
    assert after.pending == 1
    assert after.failed == 1