import asyncio import json import sqlite3 from typing import Optional, Union, cast from pydantic_core import to_jsonable_python from invokeai.app.services.invoker import Invoker from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase from invokeai.app.services.session_queue.session_queue_common import ( DEFAULT_QUEUE_ID, QUEUE_ITEM_STATUS, Batch, BatchStatus, CancelAllExceptCurrentResult, CancelByBatchIDsResult, CancelByDestinationResult, CancelByQueueIDResult, ClearResult, DeleteAllExceptCurrentResult, DeleteByDestinationResult, EnqueueBatchResult, IsEmptyResult, IsFullResult, ItemIdsResult, PruneResult, RetryItemsResult, SessionQueueCountsByDestination, SessionQueueItem, SessionQueueItemNotFoundError, SessionQueueStatus, ValueToInsertTuple, calc_session_count, prepare_values_to_insert, ) from invokeai.app.services.shared.graph import GraphExecutionState from invokeai.app.services.shared.pagination import CursorPaginatedResults from invokeai.app.services.shared.sqlite.sqlite_common import SQLiteDirection from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase class SqliteSessionQueue(SessionQueueBase): __invoker: Invoker def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() if self.__invoker.services.configuration.clear_queue_on_startup: clear_result = self.clear(DEFAULT_QUEUE_ID) if clear_result.deleted > 0: self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items") def __init__(self, db: SqliteDatabase) -> None: super().__init__() self._db = db def _set_in_progress_to_canceled(self) -> None: """ Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue. This is necessary because the invoker may have been killed while processing a queue item. """ with self._db.transaction() as cursor: cursor.execute( """--sql UPDATE session_queue SET status = 'canceled' WHERE status = 'in_progress'; """ ) def _get_current_queue_size(self, queue_id: str) -> int: """Gets the current number of pending queue items""" with self._db.transaction() as cursor: cursor.execute( """--sql SELECT count(*) FROM session_queue WHERE queue_id = ? AND status = 'pending' """, (queue_id,), ) count = cast(int, cursor.fetchone()[0]) return count def _get_highest_priority(self, queue_id: str) -> int: """Gets the highest priority value in the queue""" with self._db.transaction() as cursor: cursor.execute( """--sql SELECT MAX(priority) FROM session_queue WHERE queue_id = ? AND status = 'pending' """, (queue_id,), ) priority = cast(Union[int, None], cursor.fetchone()[0]) or 0 return priority async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult: current_queue_size = self._get_current_queue_size(queue_id) max_queue_size = self.__invoker.services.configuration.max_queue_size max_new_queue_items = max_queue_size - current_queue_size priority = 0 if prepend: priority = self._get_highest_priority(queue_id) + 1 requested_count = await asyncio.to_thread( calc_session_count, batch=batch, ) values_to_insert = await asyncio.to_thread( prepare_values_to_insert, queue_id=queue_id, batch=batch, priority=priority, max_new_queue_items=max_new_queue_items, ) enqueued_count = len(values_to_insert) with self._db.transaction() as cursor: cursor.executemany( """--sql INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, values_to_insert, ) cursor.execute( """--sql SELECT item_id FROM session_queue WHERE batch_id = ? ORDER BY item_id DESC; """, (batch.batch_id,), ) item_ids = [row[0] for row in cursor.fetchall()] enqueue_result = EnqueueBatchResult( queue_id=queue_id, requested=requested_count, enqueued=enqueued_count, batch=batch, priority=priority, item_ids=item_ids, ) self.__invoker.services.events.emit_batch_enqueued(enqueue_result) return enqueue_result def dequeue(self) -> Optional[SessionQueueItem]: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT * FROM session_queue WHERE status = 'pending' ORDER BY priority DESC, item_id ASC LIMIT 1 """ ) result = cast(Union[sqlite3.Row, None], cursor.fetchone()) if result is None: return None queue_item = SessionQueueItem.queue_item_from_dict(dict(result)) queue_item = self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress") return queue_item def get_next(self, queue_id: str) -> Optional[SessionQueueItem]: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT * FROM session_queue WHERE queue_id = ? AND status = 'pending' ORDER BY priority DESC, created_at ASC LIMIT 1 """, (queue_id,), ) result = cast(Union[sqlite3.Row, None], cursor.fetchone()) if result is None: return None return SessionQueueItem.queue_item_from_dict(dict(result)) def get_current(self, queue_id: str) -> Optional[SessionQueueItem]: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT * FROM session_queue WHERE queue_id = ? AND status = 'in_progress' LIMIT 1 """, (queue_id,), ) result = cast(Union[sqlite3.Row, None], cursor.fetchone()) if result is None: return None return SessionQueueItem.queue_item_from_dict(dict(result)) def _set_queue_item_status( self, item_id: int, status: QUEUE_ITEM_STATUS, error_type: Optional[str] = None, error_message: Optional[str] = None, error_traceback: Optional[str] = None, ) -> SessionQueueItem: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT status FROM session_queue WHERE item_id = ? """, (item_id,), ) row = cursor.fetchone() if row is None: raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}") current_status = row[0] # Only update if not already finished (completed, failed or canceled) if current_status in ("completed", "failed", "canceled"): return self.get_queue_item(item_id) with self._db.transaction() as cursor: cursor.execute( """--sql UPDATE session_queue SET status = ?, error_type = ?, error_message = ?, error_traceback = ? WHERE item_id = ? """, (status, error_type, error_message, error_traceback, item_id), ) queue_item = self.get_queue_item(item_id) batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id) queue_status = self.get_queue_status(queue_id=queue_item.queue_id) self.__invoker.services.events.emit_queue_item_status_changed(queue_item, batch_status, queue_status) return queue_item def is_empty(self, queue_id: str) -> IsEmptyResult: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT count(*) FROM session_queue WHERE queue_id = ? """, (queue_id,), ) is_empty = cast(int, cursor.fetchone()[0]) == 0 return IsEmptyResult(is_empty=is_empty) def is_full(self, queue_id: str) -> IsFullResult: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT count(*) FROM session_queue WHERE queue_id = ? """, (queue_id,), ) max_queue_size = self.__invoker.services.configuration.max_queue_size is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size return IsFullResult(is_full=is_full) def clear(self, queue_id: str) -> ClearResult: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT COUNT(*) FROM session_queue WHERE queue_id = ? """, (queue_id,), ) count = cursor.fetchone()[0] cursor.execute( """--sql DELETE FROM session_queue WHERE queue_id = ? """, (queue_id,), ) self.__invoker.services.events.emit_queue_cleared(queue_id) return ClearResult(deleted=count) def prune(self, queue_id: str) -> PruneResult: with self._db.transaction() as cursor: where = """--sql WHERE queue_id = ? AND ( status = 'completed' OR status = 'failed' OR status = 'canceled' ) """ cursor.execute( f"""--sql SELECT COUNT(*) FROM session_queue {where}; """, (queue_id,), ) count = cursor.fetchone()[0] cursor.execute( f"""--sql DELETE FROM session_queue {where}; """, (queue_id,), ) return PruneResult(deleted=count) def cancel_queue_item(self, item_id: int) -> SessionQueueItem: queue_item = self._set_queue_item_status(item_id=item_id, status="canceled") return queue_item def delete_queue_item(self, item_id: int) -> None: """Deletes a session queue item""" try: self.cancel_queue_item(item_id) except SessionQueueItemNotFoundError: pass with self._db.transaction() as cursor: cursor.execute( """--sql DELETE FROM session_queue WHERE item_id = ? """, (item_id,), ) def complete_queue_item(self, item_id: int) -> SessionQueueItem: queue_item = self._set_queue_item_status(item_id=item_id, status="completed") return queue_item def fail_queue_item( self, item_id: int, error_type: str, error_message: str, error_traceback: str, ) -> SessionQueueItem: queue_item = self._set_queue_item_status( item_id=item_id, status="failed", error_type=error_type, error_message=error_message, error_traceback=error_traceback, ) return queue_item def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult: with self._db.transaction() as cursor: current_queue_item = self.get_current(queue_id) placeholders = ", ".join(["?" for _ in batch_ids]) where = f"""--sql WHERE queue_id == ? AND batch_id IN ({placeholders}) AND status != 'canceled' AND status != 'completed' AND status != 'failed' -- We will cancel the current item separately below - skip it here AND status != 'in_progress' """ params = [queue_id] + batch_ids cursor.execute( f"""--sql SELECT COUNT(*) FROM session_queue {where}; """, tuple(params), ) count = cursor.fetchone()[0] cursor.execute( f"""--sql UPDATE session_queue SET status = 'canceled' {where}; """, tuple(params), ) if current_queue_item is not None and current_queue_item.batch_id in batch_ids: self._set_queue_item_status(current_queue_item.item_id, "canceled") return CancelByBatchIDsResult(canceled=count) def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult: with self._db.transaction() as cursor: current_queue_item = self.get_current(queue_id) where = """--sql WHERE queue_id == ? AND destination == ? AND status != 'canceled' AND status != 'completed' AND status != 'failed' -- We will cancel the current item separately below - skip it here AND status != 'in_progress' """ params = (queue_id, destination) cursor.execute( f"""--sql SELECT COUNT(*) FROM session_queue {where}; """, params, ) count = cursor.fetchone()[0] cursor.execute( f"""--sql UPDATE session_queue SET status = 'canceled' {where}; """, params, ) if current_queue_item is not None and current_queue_item.destination == destination: self._set_queue_item_status(current_queue_item.item_id, "canceled") return CancelByDestinationResult(canceled=count) def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult: with self._db.transaction() as cursor: current_queue_item = self.get_current(queue_id) if current_queue_item is not None and current_queue_item.destination == destination: self.cancel_queue_item(current_queue_item.item_id) params = (queue_id, destination) cursor.execute( """--sql SELECT COUNT(*) FROM session_queue WHERE queue_id = ? AND destination = ?; """, params, ) count = cursor.fetchone()[0] cursor.execute( """--sql DELETE FROM session_queue WHERE queue_id = ? AND destination = ?; """, params, ) return DeleteByDestinationResult(deleted=count) def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult: with self._db.transaction() as cursor: where = """--sql WHERE queue_id == ? AND status == 'pending' """ cursor.execute( f"""--sql SELECT COUNT(*) FROM session_queue {where}; """, (queue_id,), ) count = cursor.fetchone()[0] cursor.execute( f"""--sql DELETE FROM session_queue {where}; """, (queue_id,), ) return DeleteAllExceptCurrentResult(deleted=count) def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult: with self._db.transaction() as cursor: current_queue_item = self.get_current(queue_id) where = """--sql WHERE queue_id is ? AND status != 'canceled' AND status != 'completed' AND status != 'failed' -- We will cancel the current item separately below - skip it here AND status != 'in_progress' """ params = [queue_id] cursor.execute( f"""--sql SELECT COUNT(*) FROM session_queue {where}; """, tuple(params), ) count = cursor.fetchone()[0] cursor.execute( f"""--sql UPDATE session_queue SET status = 'canceled' {where}; """, tuple(params), ) if current_queue_item is not None and current_queue_item.queue_id == queue_id: self._set_queue_item_status(current_queue_item.item_id, "canceled") return CancelByQueueIDResult(canceled=count) def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult: with self._db.transaction() as cursor: where = """--sql WHERE queue_id == ? AND status == 'pending' """ cursor.execute( f"""--sql SELECT COUNT(*) FROM session_queue {where}; """, (queue_id,), ) count = cursor.fetchone()[0] cursor.execute( f"""--sql UPDATE session_queue SET status = 'canceled' {where}; """, (queue_id,), ) return CancelAllExceptCurrentResult(canceled=count) def get_queue_item(self, item_id: int) -> SessionQueueItem: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT * FROM session_queue WHERE item_id = ? """, (item_id,), ) result = cast(Union[sqlite3.Row, None], cursor.fetchone()) if result is None: raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}") return SessionQueueItem.queue_item_from_dict(dict(result)) def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem: with self._db.transaction() as cursor: # Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors # when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced # during execution. session_json = session.model_dump_json(warnings=False, exclude_none=True) cursor.execute( """--sql UPDATE session_queue SET session = ? WHERE item_id = ? """, (session_json, item_id), ) return self.get_queue_item(item_id) def list_queue_items( self, queue_id: str, limit: int, priority: int, cursor: Optional[int] = None, status: Optional[QUEUE_ITEM_STATUS] = None, destination: Optional[str] = None, ) -> CursorPaginatedResults[SessionQueueItem]: with self._db.transaction() as cursor_: item_id = cursor query = """--sql SELECT * FROM session_queue WHERE queue_id = ? """ params: list[Union[str, int]] = [queue_id] if status is not None: query += """--sql AND status = ? """ params.append(status) if destination is not None: query += """---sql AND destination = ? """ params.append(destination) if item_id is not None: query += """--sql AND (priority < ?) OR (priority = ? AND item_id > ?) """ params.extend([priority, priority, item_id]) query += """--sql ORDER BY priority DESC, item_id ASC LIMIT ? """ params.append(limit + 1) cursor_.execute(query, params) results = cast(list[sqlite3.Row], cursor_.fetchall()) items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results] has_more = False if len(items) > limit: # remove the extra item items.pop() has_more = True return CursorPaginatedResults(items=items, limit=limit, has_more=has_more) def list_all_queue_items( self, queue_id: str, destination: Optional[str] = None, ) -> list[SessionQueueItem]: """Gets all queue items that match the given parameters""" with self._db.transaction() as cursor: query = """--sql SELECT * FROM session_queue WHERE queue_id = ? """ params: list[Union[str, int]] = [queue_id] if destination is not None: query += """---sql AND destination = ? """ params.append(destination) query += """--sql ORDER BY priority DESC, item_id ASC ; """ cursor.execute(query, params) results = cast(list[sqlite3.Row], cursor.fetchall()) items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results] return items def get_queue_item_ids( self, queue_id: str, order_dir: SQLiteDirection = SQLiteDirection.Descending, ) -> ItemIdsResult: with self._db.transaction() as cursor_: query = f"""--sql SELECT item_id FROM session_queue WHERE queue_id = ? ORDER BY created_at {order_dir.value} """ query_params = [queue_id] cursor_.execute(query, query_params) result = cast(list[sqlite3.Row], cursor_.fetchall()) item_ids = [row[0] for row in result] return ItemIdsResult(item_ids=item_ids, total_count=len(item_ids)) def get_queue_status(self, queue_id: str) -> SessionQueueStatus: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT status, count(*) FROM session_queue WHERE queue_id = ? GROUP BY status """, (queue_id,), ) counts_result = cast(list[sqlite3.Row], cursor.fetchall()) current_item = self.get_current(queue_id=queue_id) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} return SessionQueueStatus( queue_id=queue_id, item_id=current_item.item_id if current_item else None, session_id=current_item.session_id if current_item else None, batch_id=current_item.batch_id if current_item else None, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), completed=counts.get("completed", 0), failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, ) def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT status, count(*), origin, destination FROM session_queue WHERE queue_id = ? AND batch_id = ? GROUP BY status """, (queue_id, batch_id), ) result = cast(list[sqlite3.Row], cursor.fetchall()) total = sum(row[1] or 0 for row in result) counts: dict[str, int] = {row[0]: row[1] for row in result} origin = result[0]["origin"] if result else None destination = result[0]["destination"] if result else None return BatchStatus( batch_id=batch_id, origin=origin, destination=destination, queue_id=queue_id, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), completed=counts.get("completed", 0), failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, ) def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination: with self._db.transaction() as cursor: cursor.execute( """--sql SELECT status, count(*) FROM session_queue WHERE queue_id = ? AND destination = ? GROUP BY status """, (queue_id, destination), ) counts_result = cast(list[sqlite3.Row], cursor.fetchall()) total = sum(row[1] or 0 for row in counts_result) counts: dict[str, int] = {row[0]: row[1] for row in counts_result} return SessionQueueCountsByDestination( queue_id=queue_id, destination=destination, pending=counts.get("pending", 0), in_progress=counts.get("in_progress", 0), completed=counts.get("completed", 0), failed=counts.get("failed", 0), canceled=counts.get("canceled", 0), total=total, ) def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult: """Retries the given queue items""" with self._db.transaction() as cursor: values_to_insert: list[ValueToInsertTuple] = [] retried_item_ids: list[int] = [] for item_id in item_ids: queue_item = self.get_queue_item(item_id) if queue_item.status not in ("failed", "canceled"): continue retried_item_ids.append(item_id) field_values_json = ( json.dumps(queue_item.field_values, default=to_jsonable_python) if queue_item.field_values else None ) workflow_json = ( json.dumps(queue_item.workflow, default=to_jsonable_python) if queue_item.workflow else None ) cloned_session = GraphExecutionState(graph=queue_item.session.graph) cloned_session_json = cloned_session.model_dump_json(warnings=False, exclude_none=True) retried_from_item_id = ( queue_item.retried_from_item_id if queue_item.retried_from_item_id is not None else queue_item.item_id ) value_to_insert: ValueToInsertTuple = ( queue_item.queue_id, cloned_session_json, cloned_session.id, queue_item.batch_id, field_values_json, queue_item.priority, workflow_json, queue_item.origin, queue_item.destination, retried_from_item_id, ) values_to_insert.append(value_to_insert) # TODO(psyche): Handle max queue size? cursor.executemany( """--sql INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, values_to_insert, ) retry_result = RetryItemsResult( queue_id=queue_id, retried_item_ids=retried_item_ids, ) self.__invoker.services.events.emit_queue_items_retried(retry_result) return retry_result