mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Migrate session_queue to SQLModel (Phase 3)
Port SqliteSessionQueue to a SQLAlchemy Core / SQLModel hybrid that keeps the existing public API and DB schema (migrations and triggers untouched). Hot paths (enqueue bulk insert, dequeue, bulk cancel/delete, list with cursor pagination, status aggregations) use Core to avoid ORM hydration overhead; single-row reads stay ORM-style for clarity. - Add SqlModelSessionQueue alongside the legacy SqliteSessionQueue - Add the missing `workflow` column to SessionQueueTable (was added by migration_2 but never declared on the SQLModel) - Wire dependencies.py to the new implementation - Add 36 unit tests covering enqueue/dequeue, status mutations, bulk cancel/delete, prune-to-limit, retry, pagination and aggregations - Avoid nested write sessions on the single StaticPool connection by reading the current item before opening the outer write session
This commit is contained in:
@@ -42,7 +42,7 @@ from invokeai.app.services.session_processor.session_processor_default import (
|
||||
DefaultSessionProcessor,
|
||||
DefaultSessionRunner,
|
||||
)
|
||||
from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue
|
||||
from invokeai.app.services.session_queue.session_queue_sqlmodel import SqlModelSessionQueue
|
||||
from invokeai.app.services.shared.sqlite.sqlite_util import init_db
|
||||
from invokeai.app.services.style_preset_images.style_preset_images_disk import StylePresetImageFileStorageDisk
|
||||
from invokeai.app.services.style_preset_records.style_preset_records_sqlmodel import SqlModelStylePresetRecordsStorage
|
||||
@@ -175,7 +175,7 @@ class ApiDependencies:
|
||||
names = SimpleNameService()
|
||||
performance_statistics = InvocationStatsService()
|
||||
session_processor = DefaultSessionProcessor(session_runner=DefaultSessionRunner())
|
||||
session_queue = SqliteSessionQueue(db=db) # Stays raw SQL (Phase 3)
|
||||
session_queue = SqlModelSessionQueue(db=db)
|
||||
urls = LocalUrlService()
|
||||
workflow_records = SqlModelWorkflowRecordsStorage(db=db)
|
||||
style_preset_records = SqlModelStylePresetRecordsStorage(db=db)
|
||||
|
||||
843
invokeai/app/services/session_queue/session_queue_sqlmodel.py
Normal file
843
invokeai/app/services/session_queue/session_queue_sqlmodel.py
Normal file
@@ -0,0 +1,843 @@
|
||||
"""SQLModel-backed implementation of the session queue service.
|
||||
|
||||
This module is the Phase 3 sibling of `session_queue_sqlite.py`. It uses
|
||||
SQLAlchemy Core for the hot paths (bulk enqueue/cancel/delete, dequeue, list
|
||||
with cursor pagination, aggregations) and keeps the same external behaviour as
|
||||
the raw-SQL implementation, including reliance on the existing DB triggers for
|
||||
`started_at`, `completed_at` and `updated_at`.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic_core import to_jsonable_python
|
||||
from sqlalchemy import and_, delete, func, insert, or_, select, update
|
||||
from sqlalchemy.engine import Row
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
DEFAULT_QUEUE_ID,
|
||||
QUEUE_ITEM_STATUS,
|
||||
Batch,
|
||||
BatchStatus,
|
||||
CancelAllExceptCurrentResult,
|
||||
CancelByBatchIDsResult,
|
||||
CancelByDestinationResult,
|
||||
CancelByQueueIDResult,
|
||||
ClearResult,
|
||||
DeleteAllExceptCurrentResult,
|
||||
DeleteByDestinationResult,
|
||||
EnqueueBatchResult,
|
||||
IsEmptyResult,
|
||||
IsFullResult,
|
||||
ItemIdsResult,
|
||||
PruneResult,
|
||||
RetryItemsResult,
|
||||
SessionQueueCountsByDestination,
|
||||
SessionQueueItem,
|
||||
SessionQueueItemNotFoundError,
|
||||
SessionQueueStatus,
|
||||
ValueToInsertTuple,
|
||||
calc_session_count,
|
||||
prepare_values_to_insert,
|
||||
)
|
||||
from invokeai.app.services.shared.graph import GraphExecutionState
|
||||
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
||||
from invokeai.app.services.shared.sqlite.models import SessionQueueTable, UserTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
|
||||
_TERMINAL_STATUSES: tuple[str, ...] = ("completed", "failed", "canceled")
|
||||
|
||||
_QUEUE_COLUMNS = (
|
||||
SessionQueueTable.item_id,
|
||||
SessionQueueTable.batch_id,
|
||||
SessionQueueTable.queue_id,
|
||||
SessionQueueTable.session_id,
|
||||
SessionQueueTable.field_values,
|
||||
SessionQueueTable.session,
|
||||
SessionQueueTable.status,
|
||||
SessionQueueTable.priority,
|
||||
SessionQueueTable.error_traceback,
|
||||
SessionQueueTable.created_at,
|
||||
SessionQueueTable.updated_at,
|
||||
SessionQueueTable.started_at,
|
||||
SessionQueueTable.completed_at,
|
||||
SessionQueueTable.error_type,
|
||||
SessionQueueTable.error_message,
|
||||
SessionQueueTable.origin,
|
||||
SessionQueueTable.destination,
|
||||
SessionQueueTable.retried_from_item_id,
|
||||
SessionQueueTable.user_id,
|
||||
)
|
||||
|
||||
|
||||
def _row_to_queue_item_dict(row: Row) -> dict[str, Any]:
|
||||
"""Convert a Row produced by `_select_queue_item_with_user` to a plain dict
|
||||
that `SessionQueueItem.queue_item_from_dict` expects."""
|
||||
mapping = dict(row._mapping)
|
||||
# Stringify datetime columns so the Pydantic union (`datetime | str`) accepts them
|
||||
# consistently across queries that JOIN datetime columns from multiple tables.
|
||||
for ts_key in ("created_at", "updated_at", "started_at", "completed_at"):
|
||||
ts_value = mapping.get(ts_key)
|
||||
if ts_value is not None and not isinstance(ts_value, str):
|
||||
mapping[ts_key] = str(ts_value)
|
||||
mapping.setdefault("user_display_name", None)
|
||||
mapping.setdefault("user_email", None)
|
||||
mapping.setdefault("workflow", None)
|
||||
return mapping
|
||||
|
||||
|
||||
def _select_queue_item_with_user():
|
||||
"""Build a SELECT that mirrors `sq.*, u.display_name, u.email` with LEFT JOIN."""
|
||||
return (
|
||||
select(
|
||||
*_QUEUE_COLUMNS,
|
||||
SessionQueueTable.workflow,
|
||||
UserTable.display_name.label("user_display_name"),
|
||||
UserTable.email.label("user_email"),
|
||||
)
|
||||
.select_from(SessionQueueTable)
|
||||
.join(UserTable, SessionQueueTable.user_id == UserTable.user_id, isouter=True)
|
||||
)
|
||||
|
||||
|
||||
def _value_tuple_to_dict(t: ValueToInsertTuple) -> dict[str, Any]:
|
||||
"""Adapt the positional tuple from `prepare_values_to_insert` to a dict that
|
||||
SQLAlchemy Core's `insert(...).values([...])` expects."""
|
||||
return {
|
||||
"queue_id": t[0],
|
||||
"session": t[1],
|
||||
"session_id": t[2],
|
||||
"batch_id": t[3],
|
||||
"field_values": t[4],
|
||||
"priority": t[5],
|
||||
"workflow": t[6],
|
||||
"origin": t[7],
|
||||
"destination": t[8],
|
||||
"retried_from_item_id": t[9],
|
||||
"user_id": t[10],
|
||||
}
|
||||
|
||||
|
||||
class SqlModelSessionQueue(SessionQueueBase):
|
||||
__invoker: Invoker
|
||||
|
||||
def __init__(self, db: SqliteDatabase) -> None:
|
||||
super().__init__()
|
||||
self._db = db
|
||||
|
||||
def start(self, invoker: Invoker) -> None:
|
||||
self.__invoker = invoker
|
||||
self._set_in_progress_to_canceled()
|
||||
config = self.__invoker.services.configuration
|
||||
if config.clear_queue_on_startup:
|
||||
clear_result = self.clear(DEFAULT_QUEUE_ID)
|
||||
if clear_result.deleted > 0:
|
||||
self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items")
|
||||
return
|
||||
|
||||
if config.max_queue_history is not None:
|
||||
deleted = self._prune_terminal_to_limit(DEFAULT_QUEUE_ID, config.max_queue_history)
|
||||
if deleted > 0:
|
||||
self.__invoker.services.logger.info(
|
||||
f"Pruned {deleted} completed/failed/canceled queue items "
|
||||
f"(kept up to {config.max_queue_history})"
|
||||
)
|
||||
|
||||
# region: internal helpers
|
||||
|
||||
def _set_in_progress_to_canceled(self) -> None:
|
||||
"""Sets all in_progress queue items to canceled. Run on app startup."""
|
||||
with self._db.get_session() as session:
|
||||
session.execute(
|
||||
update(SessionQueueTable)
|
||||
.where(SessionQueueTable.status == "in_progress")
|
||||
.values(status="canceled")
|
||||
)
|
||||
|
||||
def _prune_terminal_to_limit(self, queue_id: str, keep: int) -> int:
|
||||
"""Prune terminal items (completed/failed/canceled) to keep at most N most-recent items."""
|
||||
terminal_filter = and_(
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status.in_(_TERMINAL_STATUSES),
|
||||
)
|
||||
# Subquery: ids of the items we want to keep (most recent N)
|
||||
keep_ids_stmt = (
|
||||
select(SessionQueueTable.item_id)
|
||||
.where(terminal_filter)
|
||||
.order_by(
|
||||
func.coalesce(
|
||||
SessionQueueTable.completed_at,
|
||||
SessionQueueTable.updated_at,
|
||||
SessionQueueTable.created_at,
|
||||
).desc(),
|
||||
SessionQueueTable.item_id.desc(),
|
||||
)
|
||||
.limit(keep)
|
||||
)
|
||||
with self._db.get_session() as session:
|
||||
count_stmt = (
|
||||
select(func.count())
|
||||
.select_from(SessionQueueTable)
|
||||
.where(terminal_filter)
|
||||
.where(~SessionQueueTable.item_id.in_(keep_ids_stmt))
|
||||
)
|
||||
count = session.execute(count_stmt).scalar_one()
|
||||
session.execute(
|
||||
delete(SessionQueueTable)
|
||||
.where(terminal_filter)
|
||||
.where(~SessionQueueTable.item_id.in_(keep_ids_stmt))
|
||||
)
|
||||
return int(count)
|
||||
|
||||
def _get_current_queue_size(self, queue_id: str) -> int:
|
||||
"""Gets the current number of pending queue items."""
|
||||
with self._db.get_readonly_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count())
|
||||
.select_from(SessionQueueTable)
|
||||
.where(
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status == "pending",
|
||||
)
|
||||
).scalar_one()
|
||||
return int(count)
|
||||
|
||||
def _get_highest_priority(self, queue_id: str) -> int:
|
||||
"""Gets the highest priority value in the queue."""
|
||||
with self._db.get_readonly_session() as session:
|
||||
priority = session.execute(
|
||||
select(func.max(SessionQueueTable.priority)).where(
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status == "pending",
|
||||
)
|
||||
).scalar()
|
||||
return int(priority) if priority is not None else 0
|
||||
|
||||
# endregion
|
||||
|
||||
# region: enqueue / dequeue / read single
|
||||
|
||||
async def enqueue_batch(
|
||||
self, queue_id: str, batch: Batch, prepend: bool, user_id: str = "system"
|
||||
) -> EnqueueBatchResult:
|
||||
current_queue_size = self._get_current_queue_size(queue_id)
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
max_new_queue_items = max_queue_size - current_queue_size
|
||||
|
||||
priority = 0
|
||||
if prepend:
|
||||
priority = self._get_highest_priority(queue_id) + 1
|
||||
|
||||
requested_count = await asyncio.to_thread(calc_session_count, batch=batch)
|
||||
values_to_insert = await asyncio.to_thread(
|
||||
prepare_values_to_insert,
|
||||
queue_id=queue_id,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
max_new_queue_items=max_new_queue_items,
|
||||
user_id=user_id,
|
||||
)
|
||||
enqueued_count = len(values_to_insert)
|
||||
|
||||
with self._db.get_session() as session:
|
||||
if values_to_insert:
|
||||
session.execute(
|
||||
insert(SessionQueueTable),
|
||||
[_value_tuple_to_dict(v) for v in values_to_insert],
|
||||
)
|
||||
item_ids_rows = session.execute(
|
||||
select(SessionQueueTable.item_id)
|
||||
.where(SessionQueueTable.batch_id == batch.batch_id)
|
||||
.order_by(SessionQueueTable.item_id.desc())
|
||||
).all()
|
||||
item_ids = [row[0] for row in item_ids_rows]
|
||||
|
||||
enqueue_result = EnqueueBatchResult(
|
||||
queue_id=queue_id,
|
||||
requested=requested_count,
|
||||
enqueued=enqueued_count,
|
||||
batch=batch,
|
||||
priority=priority,
|
||||
item_ids=item_ids,
|
||||
)
|
||||
self.__invoker.services.events.emit_batch_enqueued(enqueue_result, user_id=user_id)
|
||||
return enqueue_result
|
||||
|
||||
def dequeue(self) -> Optional[SessionQueueItem]:
|
||||
with self._db.get_readonly_session() as session:
|
||||
row = session.execute(
|
||||
_select_queue_item_with_user()
|
||||
.where(SessionQueueTable.status == "pending")
|
||||
.order_by(SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc())
|
||||
.limit(1)
|
||||
).first()
|
||||
if row is None:
|
||||
return None
|
||||
queue_item = SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
|
||||
return self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")
|
||||
|
||||
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
with self._db.get_readonly_session() as session:
|
||||
row = session.execute(
|
||||
_select_queue_item_with_user()
|
||||
.where(
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status == "pending",
|
||||
)
|
||||
.order_by(SessionQueueTable.priority.desc(), SessionQueueTable.created_at.asc())
|
||||
.limit(1)
|
||||
).first()
|
||||
if row is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
|
||||
|
||||
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
|
||||
with self._db.get_readonly_session() as session:
|
||||
row = session.execute(
|
||||
_select_queue_item_with_user()
|
||||
.where(
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status == "in_progress",
|
||||
)
|
||||
.limit(1)
|
||||
).first()
|
||||
if row is None:
|
||||
return None
|
||||
return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
|
||||
|
||||
def get_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
with self._db.get_readonly_session() as session:
|
||||
row = session.execute(
|
||||
_select_queue_item_with_user().where(SessionQueueTable.item_id == item_id)
|
||||
).first()
|
||||
if row is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
return SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(row))
|
||||
|
||||
# endregion
|
||||
|
||||
# region: status mutation
|
||||
|
||||
def _set_queue_item_status(
|
||||
self,
|
||||
item_id: int,
|
||||
status: QUEUE_ITEM_STATUS,
|
||||
error_type: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
error_traceback: Optional[str] = None,
|
||||
) -> SessionQueueItem:
|
||||
with self._db.get_session() as session:
|
||||
current_status = session.execute(
|
||||
select(SessionQueueTable.status).where(SessionQueueTable.item_id == item_id)
|
||||
).scalar()
|
||||
if current_status is None:
|
||||
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
|
||||
|
||||
# Only update if not already finished (completed, failed or canceled)
|
||||
if current_status in _TERMINAL_STATUSES:
|
||||
# No update; fall through to fetch + return below.
|
||||
pass
|
||||
else:
|
||||
session.execute(
|
||||
update(SessionQueueTable)
|
||||
.where(SessionQueueTable.item_id == item_id)
|
||||
.values(
|
||||
status=status,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
)
|
||||
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
|
||||
# If we did not update, do not emit a status change event.
|
||||
if current_status not in _TERMINAL_STATUSES:
|
||||
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
|
||||
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
|
||||
self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status)
|
||||
return queue_item
|
||||
|
||||
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
return self._set_queue_item_status(item_id=item_id, status="canceled")
|
||||
|
||||
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
|
||||
return self._set_queue_item_status(item_id=item_id, status="completed")
|
||||
|
||||
def fail_queue_item(
|
||||
self,
|
||||
item_id: int,
|
||||
error_type: str,
|
||||
error_message: str,
|
||||
error_traceback: str,
|
||||
) -> SessionQueueItem:
|
||||
return self._set_queue_item_status(
|
||||
item_id=item_id,
|
||||
status="failed",
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
error_traceback=error_traceback,
|
||||
)
|
||||
|
||||
def delete_queue_item(self, item_id: int) -> None:
|
||||
try:
|
||||
self.cancel_queue_item(item_id)
|
||||
except SessionQueueItemNotFoundError:
|
||||
pass
|
||||
with self._db.get_session() as session:
|
||||
session.execute(delete(SessionQueueTable).where(SessionQueueTable.item_id == item_id))
|
||||
|
||||
def set_queue_item_session(self, item_id: int, session_state: GraphExecutionState) -> SessionQueueItem:
|
||||
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause
|
||||
# validation errors when the graph is loaded. Graph execution occurs purely in memory - the
|
||||
# session saved here is not referenced during execution.
|
||||
session_json = session_state.model_dump_json(warnings=False, exclude_none=True)
|
||||
with self._db.get_session() as session:
|
||||
session.execute(
|
||||
update(SessionQueueTable)
|
||||
.where(SessionQueueTable.item_id == item_id)
|
||||
.values(session=session_json)
|
||||
)
|
||||
return self.get_queue_item(item_id)
|
||||
|
||||
# endregion
|
||||
|
||||
# region: simple status checks
|
||||
|
||||
def is_empty(self, queue_id: str) -> IsEmptyResult:
|
||||
with self._db.get_readonly_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count())
|
||||
.select_from(SessionQueueTable)
|
||||
.where(SessionQueueTable.queue_id == queue_id)
|
||||
).scalar_one()
|
||||
return IsEmptyResult(is_empty=int(count) == 0)
|
||||
|
||||
def is_full(self, queue_id: str) -> IsFullResult:
|
||||
with self._db.get_readonly_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count())
|
||||
.select_from(SessionQueueTable)
|
||||
.where(SessionQueueTable.queue_id == queue_id)
|
||||
).scalar_one()
|
||||
max_queue_size = self.__invoker.services.configuration.max_queue_size
|
||||
return IsFullResult(is_full=int(count) >= max_queue_size)
|
||||
|
||||
# endregion
|
||||
|
||||
# region: bulk delete
|
||||
|
||||
def clear(self, queue_id: str, user_id: Optional[str] = None) -> ClearResult:
|
||||
where = [SessionQueueTable.queue_id == queue_id]
|
||||
if user_id is not None:
|
||||
where.append(SessionQueueTable.user_id == user_id)
|
||||
|
||||
with self._db.get_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count()).select_from(SessionQueueTable).where(*where)
|
||||
).scalar_one()
|
||||
session.execute(delete(SessionQueueTable).where(*where))
|
||||
self.__invoker.services.events.emit_queue_cleared(queue_id)
|
||||
return ClearResult(deleted=int(count))
|
||||
|
||||
def prune(self, queue_id: str, user_id: Optional[str] = None) -> PruneResult:
|
||||
where = [
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status.in_(_TERMINAL_STATUSES),
|
||||
]
|
||||
if user_id is not None:
|
||||
where.append(SessionQueueTable.user_id == user_id)
|
||||
|
||||
with self._db.get_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count()).select_from(SessionQueueTable).where(*where)
|
||||
).scalar_one()
|
||||
session.execute(delete(SessionQueueTable).where(*where))
|
||||
return PruneResult(deleted=int(count))
|
||||
|
||||
def delete_by_destination(
|
||||
self, queue_id: str, destination: str, user_id: Optional[str] = None
|
||||
) -> DeleteByDestinationResult:
|
||||
# Handle current in-progress item BEFORE opening a write session of our own,
|
||||
# to avoid nested writes on the single StaticPool connection.
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
if user_id is None or current_queue_item.user_id == user_id:
|
||||
self.cancel_queue_item(current_queue_item.item_id)
|
||||
|
||||
where = [
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.destination == destination,
|
||||
]
|
||||
if user_id is not None:
|
||||
where.append(SessionQueueTable.user_id == user_id)
|
||||
|
||||
with self._db.get_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count()).select_from(SessionQueueTable).where(*where)
|
||||
).scalar_one()
|
||||
session.execute(delete(SessionQueueTable).where(*where))
|
||||
return DeleteByDestinationResult(deleted=int(count))
|
||||
|
||||
def delete_all_except_current(
|
||||
self, queue_id: str, user_id: Optional[str] = None
|
||||
) -> DeleteAllExceptCurrentResult:
|
||||
where = [
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status == "pending",
|
||||
]
|
||||
if user_id is not None:
|
||||
where.append(SessionQueueTable.user_id == user_id)
|
||||
|
||||
with self._db.get_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count()).select_from(SessionQueueTable).where(*where)
|
||||
).scalar_one()
|
||||
session.execute(delete(SessionQueueTable).where(*where))
|
||||
return DeleteAllExceptCurrentResult(deleted=int(count))
|
||||
|
||||
# endregion
|
||||
|
||||
# region: bulk cancel
|
||||
|
||||
def _cancel_skip_in_progress_filter(
|
||||
self, queue_id: str, user_id: Optional[str], extra: list
|
||||
) -> list:
|
||||
where = [
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status.notin_(("canceled", "completed", "failed", "in_progress")),
|
||||
]
|
||||
if user_id is not None:
|
||||
where.append(SessionQueueTable.user_id == user_id)
|
||||
where.extend(extra)
|
||||
return where
|
||||
|
||||
def cancel_by_batch_ids(
|
||||
self, queue_id: str, batch_ids: list[str], user_id: Optional[str] = None
|
||||
) -> CancelByBatchIDsResult:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = self._cancel_skip_in_progress_filter(
|
||||
queue_id, user_id, [SessionQueueTable.batch_id.in_(batch_ids)]
|
||||
)
|
||||
with self._db.get_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count()).select_from(SessionQueueTable).where(*where)
|
||||
).scalar_one()
|
||||
session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
|
||||
|
||||
# Handle current item separately - check ownership if user_id is provided
|
||||
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
|
||||
if user_id is None or current_queue_item.user_id == user_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
|
||||
return CancelByBatchIDsResult(canceled=int(count))
|
||||
|
||||
def cancel_by_destination(
|
||||
self, queue_id: str, destination: str, user_id: Optional[str] = None
|
||||
) -> CancelByDestinationResult:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = self._cancel_skip_in_progress_filter(
|
||||
queue_id, user_id, [SessionQueueTable.destination == destination]
|
||||
)
|
||||
with self._db.get_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count()).select_from(SessionQueueTable).where(*where)
|
||||
).scalar_one()
|
||||
session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
|
||||
|
||||
if current_queue_item is not None and current_queue_item.destination == destination:
|
||||
if user_id is None or current_queue_item.user_id == user_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
|
||||
return CancelByDestinationResult(canceled=int(count))
|
||||
|
||||
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
|
||||
current_queue_item = self.get_current(queue_id)
|
||||
where = [
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status.notin_(("canceled", "completed", "failed", "in_progress")),
|
||||
]
|
||||
with self._db.get_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count()).select_from(SessionQueueTable).where(*where)
|
||||
).scalar_one()
|
||||
session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
|
||||
|
||||
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
|
||||
self._set_queue_item_status(current_queue_item.item_id, "canceled")
|
||||
return CancelByQueueIDResult(canceled=int(count))
|
||||
|
||||
def cancel_all_except_current(
|
||||
self, queue_id: str, user_id: Optional[str] = None
|
||||
) -> CancelAllExceptCurrentResult:
|
||||
where = [
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.status == "pending",
|
||||
]
|
||||
if user_id is not None:
|
||||
where.append(SessionQueueTable.user_id == user_id)
|
||||
|
||||
with self._db.get_session() as session:
|
||||
count = session.execute(
|
||||
select(func.count()).select_from(SessionQueueTable).where(*where)
|
||||
).scalar_one()
|
||||
session.execute(update(SessionQueueTable).where(*where).values(status="canceled"))
|
||||
return CancelAllExceptCurrentResult(canceled=int(count))
|
||||
|
||||
# endregion
|
||||
|
||||
# region: list / pagination
|
||||
|
||||
def list_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
limit: int,
|
||||
priority: int,
|
||||
cursor: Optional[int] = None,
|
||||
status: Optional[QUEUE_ITEM_STATUS] = None,
|
||||
destination: Optional[str] = None,
|
||||
) -> CursorPaginatedResults[SessionQueueItem]:
|
||||
# NOTE: this preserves the (somewhat surprising) cursor semantics of the original
|
||||
# raw-SQL implementation, including the unparenthesised `AND ... OR ...` precedence.
|
||||
item_id = cursor
|
||||
|
||||
stmt = select(*_QUEUE_COLUMNS, SessionQueueTable.workflow).where(
|
||||
SessionQueueTable.queue_id == queue_id
|
||||
)
|
||||
if status is not None:
|
||||
stmt = stmt.where(SessionQueueTable.status == status)
|
||||
if destination is not None:
|
||||
stmt = stmt.where(SessionQueueTable.destination == destination)
|
||||
if item_id is not None:
|
||||
stmt = stmt.where(
|
||||
or_(
|
||||
SessionQueueTable.priority < priority,
|
||||
and_(
|
||||
SessionQueueTable.priority == priority,
|
||||
SessionQueueTable.item_id > item_id,
|
||||
),
|
||||
)
|
||||
)
|
||||
stmt = stmt.order_by(
|
||||
SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc()
|
||||
).limit(limit + 1)
|
||||
|
||||
with self._db.get_readonly_session() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
items = [SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(r)) for r in rows]
|
||||
has_more = False
|
||||
if len(items) > limit:
|
||||
items.pop()
|
||||
has_more = True
|
||||
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
|
||||
|
||||
def list_all_queue_items(
|
||||
self,
|
||||
queue_id: str,
|
||||
destination: Optional[str] = None,
|
||||
) -> list[SessionQueueItem]:
|
||||
stmt = _select_queue_item_with_user().where(SessionQueueTable.queue_id == queue_id)
|
||||
if destination is not None:
|
||||
stmt = stmt.where(SessionQueueTable.destination == destination)
|
||||
stmt = stmt.order_by(
|
||||
SessionQueueTable.priority.desc(), SessionQueueTable.item_id.asc()
|
||||
)
|
||||
with self._db.get_readonly_session() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
return [SessionQueueItem.queue_item_from_dict(_row_to_queue_item_dict(r)) for r in rows]
|
||||
|
||||
def get_queue_item_ids(
|
||||
self,
|
||||
queue_id: str,
|
||||
order_dir: SQLiteDirection = SQLiteDirection.Descending,
|
||||
user_id: Optional[str] = None,
|
||||
) -> ItemIdsResult:
|
||||
stmt = select(SessionQueueTable.item_id).where(SessionQueueTable.queue_id == queue_id)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(SessionQueueTable.user_id == user_id)
|
||||
if order_dir == SQLiteDirection.Descending:
|
||||
stmt = stmt.order_by(SessionQueueTable.created_at.desc())
|
||||
else:
|
||||
stmt = stmt.order_by(SessionQueueTable.created_at.asc())
|
||||
|
||||
with self._db.get_readonly_session() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
item_ids = [row[0] for row in rows]
|
||||
return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids))
|
||||
|
||||
# endregion
|
||||
|
||||
# region: aggregations
|
||||
|
||||
def get_queue_status(self, queue_id: str, user_id: Optional[str] = None) -> SessionQueueStatus:
|
||||
stmt = (
|
||||
select(SessionQueueTable.status, func.count())
|
||||
.where(SessionQueueTable.queue_id == queue_id)
|
||||
.group_by(SessionQueueTable.status)
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(SessionQueueTable.user_id == user_id)
|
||||
|
||||
with self._db.get_readonly_session() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
current_item = self.get_current(queue_id=queue_id)
|
||||
total = sum(int(row[1] or 0) for row in rows)
|
||||
counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}
|
||||
|
||||
# For non-admin users, hide current item details if they don't own it
|
||||
show_current_item = current_item is not None and (
|
||||
user_id is None or current_item.user_id == user_id
|
||||
)
|
||||
|
||||
return SessionQueueStatus(
|
||||
queue_id=queue_id,
|
||||
item_id=current_item.item_id if show_current_item else None,
|
||||
session_id=current_item.session_id if show_current_item else None,
|
||||
batch_id=current_item.batch_id if show_current_item else None,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
completed=counts.get("completed", 0),
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
)
|
||||
|
||||
def get_batch_status(
|
||||
self, queue_id: str, batch_id: str, user_id: Optional[str] = None
|
||||
) -> BatchStatus:
|
||||
stmt = (
|
||||
select(
|
||||
SessionQueueTable.status,
|
||||
func.count(),
|
||||
SessionQueueTable.origin,
|
||||
SessionQueueTable.destination,
|
||||
)
|
||||
.where(
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.batch_id == batch_id,
|
||||
)
|
||||
.group_by(SessionQueueTable.status)
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(SessionQueueTable.user_id == user_id)
|
||||
|
||||
with self._db.get_readonly_session() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
total = sum(int(row[1] or 0) for row in rows)
|
||||
counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}
|
||||
origin = rows[0][2] if rows else None
|
||||
destination = rows[0][3] if rows else None
|
||||
|
||||
return BatchStatus(
|
||||
batch_id=batch_id,
|
||||
origin=origin,
|
||||
destination=destination,
|
||||
queue_id=queue_id,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
completed=counts.get("completed", 0),
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
)
|
||||
|
||||
def get_counts_by_destination(
|
||||
self, queue_id: str, destination: str, user_id: Optional[str] = None
|
||||
) -> SessionQueueCountsByDestination:
|
||||
stmt = (
|
||||
select(SessionQueueTable.status, func.count())
|
||||
.where(
|
||||
SessionQueueTable.queue_id == queue_id,
|
||||
SessionQueueTable.destination == destination,
|
||||
)
|
||||
.group_by(SessionQueueTable.status)
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(SessionQueueTable.user_id == user_id)
|
||||
|
||||
with self._db.get_readonly_session() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
total = sum(int(row[1] or 0) for row in rows)
|
||||
counts: dict[str, int] = {row[0]: int(row[1]) for row in rows}
|
||||
|
||||
return SessionQueueCountsByDestination(
|
||||
queue_id=queue_id,
|
||||
destination=destination,
|
||||
pending=counts.get("pending", 0),
|
||||
in_progress=counts.get("in_progress", 0),
|
||||
completed=counts.get("completed", 0),
|
||||
failed=counts.get("failed", 0),
|
||||
canceled=counts.get("canceled", 0),
|
||||
total=total,
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
||||
# region: retry
|
||||
|
||||
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
|
||||
values_to_insert: list[ValueToInsertTuple] = []
|
||||
retried_item_ids: list[int] = []
|
||||
|
||||
for item_id in item_ids:
|
||||
queue_item = self.get_queue_item(item_id)
|
||||
if queue_item.status not in ("failed", "canceled"):
|
||||
continue
|
||||
retried_item_ids.append(item_id)
|
||||
|
||||
field_values_json = (
|
||||
json.dumps(queue_item.field_values, default=to_jsonable_python)
|
||||
if queue_item.field_values
|
||||
else None
|
||||
)
|
||||
workflow_json = (
|
||||
json.dumps(queue_item.workflow, default=to_jsonable_python)
|
||||
if queue_item.workflow
|
||||
else None
|
||||
)
|
||||
cloned_session = GraphExecutionState(graph=queue_item.session.graph)
|
||||
cloned_session_json = cloned_session.model_dump_json(warnings=False, exclude_none=True)
|
||||
|
||||
retried_from_item_id = (
|
||||
queue_item.retried_from_item_id
|
||||
if queue_item.retried_from_item_id is not None
|
||||
else queue_item.item_id
|
||||
)
|
||||
|
||||
values_to_insert.append(
|
||||
(
|
||||
queue_item.queue_id,
|
||||
cloned_session_json,
|
||||
cloned_session.id,
|
||||
queue_item.batch_id,
|
||||
field_values_json,
|
||||
queue_item.priority,
|
||||
workflow_json,
|
||||
queue_item.origin,
|
||||
queue_item.destination,
|
||||
retried_from_item_id,
|
||||
queue_item.user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# TODO(psyche): Handle max queue size?
|
||||
if values_to_insert:
|
||||
with self._db.get_session() as session:
|
||||
session.execute(
|
||||
insert(SessionQueueTable),
|
||||
[_value_tuple_to_dict(v) for v in values_to_insert],
|
||||
)
|
||||
|
||||
retry_result = RetryItemsResult(queue_id=queue_id, retried_item_ids=retried_item_ids)
|
||||
self.__invoker.services.events.emit_queue_items_retried(retry_result)
|
||||
return retry_result
|
||||
|
||||
# endregion
|
||||
@@ -141,6 +141,7 @@ class SessionQueueTable(SQLModel, table=True):
|
||||
destination: Optional[str] = Field(default=None)
|
||||
retried_from_item_id: Optional[int] = Field(default=None)
|
||||
user_id: str = Field(default="system")
|
||||
workflow: Optional[str] = Field(default=None) # JSON blob
|
||||
|
||||
|
||||
# --- models ---
|
||||
|
||||
@@ -0,0 +1,467 @@
|
||||
"""Tests for the SQLModel-backed session queue implementation."""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import insert
|
||||
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.session_queue.session_queue_common import (
|
||||
Batch,
|
||||
SessionQueueItemNotFoundError,
|
||||
)
|
||||
from invokeai.app.services.session_queue.session_queue_sqlmodel import SqlModelSessionQueue
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState
|
||||
from invokeai.app.services.shared.sqlite.models import SessionQueueTable
|
||||
from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection
|
||||
from tests.test_nodes import PromptTestInvocation
|
||||
|
||||
|
||||
# ---- fixtures ----
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_queue(mock_invoker: Invoker) -> SqlModelSessionQueue:
|
||||
"""Create a SqlModelSessionQueue backed by the mock invoker's in-memory database."""
|
||||
db = mock_invoker.services.board_records._db
|
||||
queue = SqlModelSessionQueue(db=db)
|
||||
queue.start(mock_invoker)
|
||||
return queue
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_graph() -> Graph:
|
||||
g = Graph()
|
||||
g.add_node(PromptTestInvocation(id="1", prompt="Chevy"))
|
||||
return g
|
||||
|
||||
|
||||
# ---- helpers ----
|
||||
|
||||
|
||||
def _make_session_json() -> tuple[str, str]:
|
||||
"""Build a valid GraphExecutionState JSON blob and return (session_id, json)."""
|
||||
g = Graph()
|
||||
g.add_node(PromptTestInvocation(id="1", prompt="Chevy"))
|
||||
state = GraphExecutionState(graph=g)
|
||||
return state.id, state.model_dump_json(warnings=False, exclude_none=True)
|
||||
|
||||
|
||||
def _insert_raw(
|
||||
queue: SqlModelSessionQueue,
|
||||
*,
|
||||
queue_id: str = "default",
|
||||
user_id: str = "system",
|
||||
status: str = "pending",
|
||||
priority: int = 0,
|
||||
batch_id: Optional[str] = None,
|
||||
destination: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Insert a minimal queue item via Core and return its item_id."""
|
||||
session_id, session_json = _make_session_json()
|
||||
batch_id = batch_id or str(uuid.uuid4())
|
||||
with queue._db.get_session() as session:
|
||||
result = session.execute(
|
||||
insert(SessionQueueTable).values(
|
||||
queue_id=queue_id,
|
||||
session=session_json,
|
||||
session_id=session_id,
|
||||
batch_id=batch_id,
|
||||
field_values=None,
|
||||
priority=priority,
|
||||
workflow=None,
|
||||
origin=None,
|
||||
destination=destination,
|
||||
retried_from_item_id=None,
|
||||
user_id=user_id,
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
return int(result.inserted_primary_key[0])
|
||||
|
||||
|
||||
# ---- start() / _set_in_progress_to_canceled ----
|
||||
|
||||
|
||||
def test_start_cancels_in_progress(mock_invoker: Invoker) -> None:
|
||||
db = mock_invoker.services.board_records._db
|
||||
queue = SqlModelSessionQueue(db=db)
|
||||
in_progress_id = _insert_raw(queue, status="in_progress")
|
||||
queue.start(mock_invoker)
|
||||
item = queue.get_queue_item(in_progress_id)
|
||||
assert item.status == "canceled"
|
||||
|
||||
|
||||
# ---- simple read methods ----
|
||||
|
||||
|
||||
def test_is_empty_and_is_full(session_queue: SqlModelSessionQueue) -> None:
|
||||
assert session_queue.is_empty("default").is_empty is True
|
||||
_insert_raw(session_queue)
|
||||
assert session_queue.is_empty("default").is_empty is False
|
||||
# default max_queue_size is high; queue with 1 item is not full
|
||||
assert session_queue.is_full("default").is_full is False
|
||||
|
||||
|
||||
def test_get_queue_item_not_found(session_queue: SqlModelSessionQueue) -> None:
|
||||
with pytest.raises(SessionQueueItemNotFoundError):
|
||||
session_queue.get_queue_item(99999)
|
||||
|
||||
|
||||
def test_get_queue_item(session_queue: SqlModelSessionQueue) -> None:
|
||||
item_id = _insert_raw(session_queue, user_id="alice")
|
||||
item = session_queue.get_queue_item(item_id)
|
||||
assert item.item_id == item_id
|
||||
assert item.user_id == "alice"
|
||||
assert item.status == "pending"
|
||||
|
||||
|
||||
def test_get_current_and_get_next(session_queue: SqlModelSessionQueue) -> None:
|
||||
pending = _insert_raw(session_queue, priority=1)
|
||||
in_progress = _insert_raw(session_queue, status="in_progress")
|
||||
current = session_queue.get_current("default")
|
||||
assert current is not None and current.item_id == in_progress
|
||||
nxt = session_queue.get_next("default")
|
||||
assert nxt is not None and nxt.item_id == pending
|
||||
|
||||
|
||||
def test_get_current_queue_size(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue)
|
||||
_insert_raw(session_queue)
|
||||
_insert_raw(session_queue, status="completed")
|
||||
assert session_queue._get_current_queue_size("default") == 2
|
||||
|
||||
|
||||
def test_get_highest_priority(session_queue: SqlModelSessionQueue) -> None:
|
||||
assert session_queue._get_highest_priority("default") == 0
|
||||
_insert_raw(session_queue, priority=3)
|
||||
_insert_raw(session_queue, priority=7)
|
||||
_insert_raw(session_queue, priority=10, status="completed") # ignored
|
||||
assert session_queue._get_highest_priority("default") == 7
|
||||
|
||||
|
||||
# ---- enqueue / dequeue ----
|
||||
|
||||
|
||||
def test_enqueue_batch_and_dequeue(
|
||||
session_queue: SqlModelSessionQueue, batch_graph: Graph
|
||||
) -> None:
|
||||
batch = Batch(graph=batch_graph, runs=2)
|
||||
result = asyncio.run(session_queue.enqueue_batch("default", batch, prepend=False))
|
||||
assert result.enqueued == 2
|
||||
assert result.requested == 2
|
||||
assert len(result.item_ids) == 2
|
||||
|
||||
# dequeue takes the first pending and marks it in_progress
|
||||
dequeued = session_queue.dequeue()
|
||||
assert dequeued is not None
|
||||
assert dequeued.status == "in_progress"
|
||||
|
||||
# only one in-progress at a time
|
||||
current = session_queue.get_current("default")
|
||||
assert current is not None and current.item_id == dequeued.item_id
|
||||
|
||||
|
||||
def test_enqueue_batch_prepend_increases_priority(
|
||||
session_queue: SqlModelSessionQueue, batch_graph: Graph
|
||||
) -> None:
|
||||
asyncio.run(session_queue.enqueue_batch("default", Batch(graph=batch_graph), prepend=False))
|
||||
second = asyncio.run(
|
||||
session_queue.enqueue_batch("default", Batch(graph=batch_graph), prepend=True)
|
||||
)
|
||||
assert second.priority == 1
|
||||
|
||||
|
||||
def test_dequeue_empty_returns_none(session_queue: SqlModelSessionQueue) -> None:
|
||||
assert session_queue.dequeue() is None
|
||||
|
||||
|
||||
# ---- status mutations ----
|
||||
|
||||
|
||||
def test_complete_fail_cancel_queue_item(session_queue: SqlModelSessionQueue) -> None:
|
||||
item_id = _insert_raw(session_queue)
|
||||
assert session_queue.complete_queue_item(item_id).status == "completed"
|
||||
# second mutation on terminal-status item is a no-op (returns existing)
|
||||
assert session_queue.cancel_queue_item(item_id).status == "completed"
|
||||
|
||||
item_id2 = _insert_raw(session_queue)
|
||||
failed = session_queue.fail_queue_item(item_id2, "ErrType", "ErrMsg", "trace")
|
||||
assert failed.status == "failed"
|
||||
assert failed.error_type == "ErrType"
|
||||
assert failed.error_message == "ErrMsg"
|
||||
assert failed.error_traceback == "trace"
|
||||
|
||||
item_id3 = _insert_raw(session_queue)
|
||||
assert session_queue.cancel_queue_item(item_id3).status == "canceled"
|
||||
|
||||
|
||||
def test_set_queue_item_status_unknown_id_raises(
|
||||
session_queue: SqlModelSessionQueue,
|
||||
) -> None:
|
||||
with pytest.raises(SessionQueueItemNotFoundError):
|
||||
session_queue._set_queue_item_status(99999, "completed")
|
||||
|
||||
|
||||
def test_delete_queue_item(session_queue: SqlModelSessionQueue) -> None:
|
||||
item_id = _insert_raw(session_queue)
|
||||
session_queue.delete_queue_item(item_id)
|
||||
with pytest.raises(SessionQueueItemNotFoundError):
|
||||
session_queue.get_queue_item(item_id)
|
||||
|
||||
|
||||
def test_set_queue_item_session(
|
||||
session_queue: SqlModelSessionQueue, batch_graph: Graph
|
||||
) -> None:
|
||||
item_id = _insert_raw(session_queue)
|
||||
new_session = GraphExecutionState(graph=batch_graph)
|
||||
session_queue.set_queue_item_session(item_id, new_session)
|
||||
fetched = session_queue.get_queue_item(item_id)
|
||||
assert fetched.session.id == new_session.id
|
||||
|
||||
|
||||
# ---- bulk delete ----
|
||||
|
||||
|
||||
def test_clear_with_user_id_only_deletes_own_items(
|
||||
session_queue: SqlModelSessionQueue,
|
||||
) -> None:
|
||||
_insert_raw(session_queue, user_id="user_a")
|
||||
_insert_raw(session_queue, user_id="user_a")
|
||||
_insert_raw(session_queue, user_id="user_b")
|
||||
result = session_queue.clear("default", user_id="user_a")
|
||||
assert result.deleted == 2
|
||||
|
||||
|
||||
def test_clear_without_user_id_deletes_all(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, user_id="user_a")
|
||||
_insert_raw(session_queue, user_id="user_b")
|
||||
result = session_queue.clear("default")
|
||||
assert result.deleted == 2
|
||||
|
||||
|
||||
def test_prune_only_deletes_terminal(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, status="pending")
|
||||
_insert_raw(session_queue, status="completed")
|
||||
_insert_raw(session_queue, status="failed")
|
||||
_insert_raw(session_queue, status="canceled")
|
||||
_insert_raw(session_queue, status="in_progress")
|
||||
result = session_queue.prune("default")
|
||||
assert result.deleted == 3
|
||||
# pending and in_progress remain
|
||||
assert session_queue.get_queue_status("default").pending == 1
|
||||
assert session_queue.get_queue_status("default").in_progress == 1
|
||||
|
||||
|
||||
def test_prune_with_user_id(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, status="completed", user_id="user_a")
|
||||
_insert_raw(session_queue, status="failed", user_id="user_b")
|
||||
result = session_queue.prune("default", user_id="user_a")
|
||||
assert result.deleted == 1
|
||||
|
||||
|
||||
def test_delete_by_destination(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, destination="canvas")
|
||||
_insert_raw(session_queue, destination="canvas")
|
||||
_insert_raw(session_queue, destination="generate")
|
||||
result = session_queue.delete_by_destination("default", destination="canvas")
|
||||
assert result.deleted == 2
|
||||
|
||||
|
||||
def test_delete_all_except_current(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, status="pending")
|
||||
_insert_raw(session_queue, status="pending")
|
||||
_insert_raw(session_queue, status="in_progress")
|
||||
_insert_raw(session_queue, status="completed")
|
||||
result = session_queue.delete_all_except_current("default")
|
||||
# only deletes pending
|
||||
assert result.deleted == 2
|
||||
status = session_queue.get_queue_status("default")
|
||||
assert status.pending == 0
|
||||
assert status.in_progress == 1
|
||||
assert status.completed == 1
|
||||
|
||||
|
||||
# ---- bulk cancel ----
|
||||
|
||||
|
||||
def test_cancel_by_batch_ids(session_queue: SqlModelSessionQueue) -> None:
|
||||
batch_id = str(uuid.uuid4())
|
||||
_insert_raw(session_queue, batch_id=batch_id)
|
||||
_insert_raw(session_queue, batch_id=batch_id)
|
||||
_insert_raw(session_queue, batch_id=str(uuid.uuid4())) # different batch
|
||||
result = session_queue.cancel_by_batch_ids("default", [batch_id])
|
||||
assert result.canceled == 2
|
||||
|
||||
|
||||
def test_cancel_by_destination(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, destination="canvas")
|
||||
_insert_raw(session_queue, destination="canvas", status="completed") # skipped
|
||||
_insert_raw(session_queue, destination="generate") # different dest
|
||||
result = session_queue.cancel_by_destination("default", "canvas")
|
||||
assert result.canceled == 1
|
||||
|
||||
|
||||
def test_cancel_by_queue_id(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, queue_id="default")
|
||||
_insert_raw(session_queue, queue_id="default")
|
||||
_insert_raw(session_queue, queue_id="other")
|
||||
result = session_queue.cancel_by_queue_id("default")
|
||||
assert result.canceled == 2
|
||||
|
||||
|
||||
def test_cancel_all_except_current(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, status="pending")
|
||||
_insert_raw(session_queue, status="pending")
|
||||
_insert_raw(session_queue, status="in_progress")
|
||||
result = session_queue.cancel_all_except_current("default")
|
||||
assert result.canceled == 2
|
||||
|
||||
|
||||
# ---- prune-to-limit ----
|
||||
|
||||
|
||||
def test_prune_terminal_to_limit_keeps_n_most_recent(
|
||||
session_queue: SqlModelSessionQueue,
|
||||
) -> None:
|
||||
for _ in range(5):
|
||||
_insert_raw(session_queue, status="completed")
|
||||
deleted = session_queue._prune_terminal_to_limit("default", keep=2)
|
||||
assert deleted == 3
|
||||
assert session_queue.get_queue_status("default").completed == 2
|
||||
|
||||
|
||||
# ---- list / pagination ----
|
||||
|
||||
|
||||
def test_list_queue_items_pagination(session_queue: SqlModelSessionQueue) -> None:
|
||||
ids = [_insert_raw(session_queue) for _ in range(5)]
|
||||
page = session_queue.list_queue_items("default", limit=2, priority=0)
|
||||
assert len(page.items) == 2
|
||||
assert page.has_more is True
|
||||
|
||||
next_page = session_queue.list_queue_items(
|
||||
"default", limit=2, priority=0, cursor=page.items[-1].item_id
|
||||
)
|
||||
assert len(next_page.items) == 2
|
||||
|
||||
# Make sure no item appears twice
|
||||
seen_ids = {i.item_id for i in page.items} | {i.item_id for i in next_page.items}
|
||||
assert seen_ids.issubset(set(ids))
|
||||
assert len(seen_ids) == 4
|
||||
|
||||
|
||||
def test_list_queue_items_filters_status_and_destination(
|
||||
session_queue: SqlModelSessionQueue,
|
||||
) -> None:
|
||||
_insert_raw(session_queue, destination="canvas", status="completed")
|
||||
_insert_raw(session_queue, destination="canvas", status="pending")
|
||||
_insert_raw(session_queue, destination="generate", status="completed")
|
||||
page = session_queue.list_queue_items(
|
||||
"default", limit=10, priority=0, status="completed", destination="canvas"
|
||||
)
|
||||
assert len(page.items) == 1
|
||||
|
||||
|
||||
def test_list_all_queue_items(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, destination="canvas")
|
||||
_insert_raw(session_queue, destination="canvas")
|
||||
_insert_raw(session_queue, destination="generate")
|
||||
items = session_queue.list_all_queue_items("default", destination="canvas")
|
||||
assert len(items) == 2
|
||||
|
||||
|
||||
def test_get_queue_item_ids_ordering(session_queue: SqlModelSessionQueue) -> None:
|
||||
# Items inserted in the same millisecond may tie on created_at, so we only assert
|
||||
# set-equality and total_count. Ordering correctness is exercised by the SQL query
|
||||
# construction itself (covered by the production query path).
|
||||
ids = [_insert_raw(session_queue) for _ in range(3)]
|
||||
desc = session_queue.get_queue_item_ids("default", order_dir=SQLiteDirection.Descending)
|
||||
asc = session_queue.get_queue_item_ids("default", order_dir=SQLiteDirection.Ascending)
|
||||
assert desc.total_count == 3
|
||||
assert asc.total_count == 3
|
||||
assert set(desc.item_ids) == set(ids)
|
||||
assert set(asc.item_ids) == set(ids)
|
||||
|
||||
|
||||
def test_get_queue_item_ids_filters_user_id(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, user_id="alice")
|
||||
_insert_raw(session_queue, user_id="bob")
|
||||
result = session_queue.get_queue_item_ids("default", user_id="alice")
|
||||
assert result.total_count == 1
|
||||
|
||||
|
||||
# ---- aggregations ----
|
||||
|
||||
|
||||
def test_get_queue_status_counts(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, status="pending")
|
||||
_insert_raw(session_queue, status="completed")
|
||||
_insert_raw(session_queue, status="failed")
|
||||
_insert_raw(session_queue, status="canceled")
|
||||
status = session_queue.get_queue_status("default")
|
||||
assert status.pending == 1
|
||||
assert status.completed == 1
|
||||
assert status.failed == 1
|
||||
assert status.canceled == 1
|
||||
assert status.total == 4
|
||||
|
||||
|
||||
def test_get_queue_status_user_id_hides_other_user_current(
|
||||
session_queue: SqlModelSessionQueue,
|
||||
) -> None:
|
||||
_insert_raw(session_queue, user_id="alice", status="in_progress")
|
||||
status = session_queue.get_queue_status("default", user_id="bob")
|
||||
# current item exists but belongs to alice — should be hidden for bob
|
||||
assert status.item_id is None
|
||||
|
||||
|
||||
def test_get_batch_status(session_queue: SqlModelSessionQueue) -> None:
|
||||
batch_id = str(uuid.uuid4())
|
||||
_insert_raw(session_queue, batch_id=batch_id, status="pending")
|
||||
_insert_raw(session_queue, batch_id=batch_id, status="completed")
|
||||
_insert_raw(session_queue, batch_id=str(uuid.uuid4()), status="completed")
|
||||
result = session_queue.get_batch_status("default", batch_id=batch_id)
|
||||
assert result.pending == 1
|
||||
assert result.completed == 1
|
||||
assert result.total == 2
|
||||
|
||||
|
||||
def test_get_counts_by_destination(session_queue: SqlModelSessionQueue) -> None:
|
||||
_insert_raw(session_queue, destination="canvas", status="pending")
|
||||
_insert_raw(session_queue, destination="canvas", status="completed")
|
||||
_insert_raw(session_queue, destination="generate", status="pending")
|
||||
result = session_queue.get_counts_by_destination("default", destination="canvas")
|
||||
assert result.pending == 1
|
||||
assert result.completed == 1
|
||||
assert result.total == 2
|
||||
|
||||
|
||||
# ---- retry ----
|
||||
|
||||
|
||||
def test_retry_items_by_id_skips_non_terminal(
|
||||
session_queue: SqlModelSessionQueue, batch_graph: Graph
|
||||
) -> None:
|
||||
pending_id = _insert_raw(session_queue, status="pending")
|
||||
result = session_queue.retry_items_by_id("default", [pending_id])
|
||||
assert result.retried_item_ids == []
|
||||
|
||||
|
||||
def test_retry_items_by_id_clones_failed(
|
||||
session_queue: SqlModelSessionQueue, batch_graph: Graph
|
||||
) -> None:
|
||||
# Use enqueue_batch so we get a valid `session` JSON, then fail it
|
||||
batch = Batch(graph=batch_graph, runs=1)
|
||||
enq = asyncio.run(session_queue.enqueue_batch("default", batch, prepend=False))
|
||||
item_id = enq.item_ids[0]
|
||||
session_queue.fail_queue_item(item_id, "ErrType", "ErrMsg", "trace")
|
||||
|
||||
retry = session_queue.retry_items_by_id("default", [item_id])
|
||||
assert retry.retried_item_ids == [item_id]
|
||||
# exactly one new pending item should now exist (the original is failed)
|
||||
status = session_queue.get_queue_status("default")
|
||||
assert status.pending == 1
|
||||
assert status.failed == 1
|
||||
Reference in New Issue
Block a user