experiment(app): avoid nested cursors in session_queue service

SQLite cursors are meant to be lightweight, short-lived objects that are not reused. For whatever reason, we create one per service and reuse it for the entire app lifecycle.

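This can cause issues where a single cursor is used twice at the same time in different transactions - for example, when a method that is using the shared cursor calls another method that executes its own query on that same cursor.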

This experiment makes the session queue use a fresh cursor for each method, hopefully fixing the issue.
This commit is contained in:
psychedelicious
2025-03-03 17:28:15 +10:00
parent feee4c49a2
commit e57f0ff055

View File

@@ -37,8 +37,6 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteSessionQueue(SessionQueueBase):
__invoker: Invoker
__conn: sqlite3.Connection
__cursor: sqlite3.Cursor
def start(self, invoker: Invoker) -> None:
self.__invoker = invoker
@@ -54,8 +52,7 @@ class SqliteSessionQueue(SessionQueueBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
self.__conn = db.conn
self.__cursor = self.__conn.cursor()
self._conn = db.conn
def _set_in_progress_to_canceled(self) -> None:
"""
@@ -63,7 +60,8 @@ class SqliteSessionQueue(SessionQueueBase):
This is necessary because the invoker may have been killed while processing a queue item.
"""
try:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -71,12 +69,13 @@ class SqliteSessionQueue(SessionQueueBase):
"""
)
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -86,11 +85,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(int, self.__cursor.fetchone()[0])
return cast(int, cursor.fetchone()[0])
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT MAX(priority)
FROM session_queue
@@ -100,13 +100,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0
return cast(Union[int, None], cursor.fetchone()[0]) or 0
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
return await asyncio.to_thread(self._enqueue_batch, queue_id, batch, prepend)
def _enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
try:
cursor = self._conn.cursor()
# TODO: how does this work in a multi-user scenario?
current_queue_size = self._get_current_queue_size(queue_id)
max_queue_size = self.__invoker.services.configuration.max_queue_size
@@ -128,16 +129,16 @@ class SqliteSessionQueue(SessionQueueBase):
if requested_count > enqueued_count:
values_to_insert = values_to_insert[:max_new_queue_items]
self.__cursor.executemany(
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
values_to_insert,
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
@@ -150,7 +151,8 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
@@ -161,7 +163,7 @@ class SqliteSessionQueue(SessionQueueBase):
LIMIT 1
"""
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -169,7 +171,8 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
@@ -183,13 +186,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT *
FROM session_queue
@@ -200,7 +204,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -214,7 +218,8 @@ class SqliteSessionQueue(SessionQueueBase):
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
try:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
SET status = ?, error_type = ?, error_message = ?, error_traceback = ?
@@ -222,9 +227,9 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
@@ -233,7 +238,8 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -241,11 +247,12 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT count(*)
FROM session_queue
@@ -254,12 +261,13 @@ class SqliteSessionQueue(SessionQueueBase):
(queue_id,),
)
max_queue_size = self.__invoker.services.configuration.max_queue_size
is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
try:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -267,8 +275,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
"""--sql
DELETE
FROM session_queue
@@ -276,15 +284,16 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id = ?
@@ -294,7 +303,7 @@ class SqliteSessionQueue(SessionQueueBase):
OR status = 'canceled'
)
"""
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -302,8 +311,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
DELETE
FROM session_queue
@@ -311,9 +320,9 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
return PruneResult(deleted=count)
@@ -343,6 +352,7 @@ class SqliteSessionQueue(SessionQueueBase):
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -354,7 +364,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = [queue_id] + batch_ids
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -362,8 +372,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -371,16 +381,17 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self.__conn.commit()
self._conn.commit()
if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -391,7 +402,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = (queue_id, destination)
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -399,8 +410,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -408,16 +419,17 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
self.__conn.commit()
self._conn.commit()
if current_queue_item is not None and current_queue_item.destination == destination:
self._set_queue_item_status(current_queue_item.item_id, "canceled")
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
return CancelByDestinationResult(canceled=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
try:
cursor = self._conn.cursor()
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -427,7 +439,7 @@ class SqliteSessionQueue(SessionQueueBase):
AND status != 'failed'
"""
params = [queue_id]
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -435,8 +447,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -444,7 +456,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
self.__conn.commit()
self._conn.commit()
if current_queue_item is not None and current_queue_item.queue_id == queue_id:
batch_status = self.get_batch_status(queue_id=queue_id, batch_id=current_queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_id)
@@ -452,18 +464,19 @@ class SqliteSessionQueue(SessionQueueBase):
current_queue_item, batch_status, queue_status
)
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
try:
cursor = self._conn.cursor()
where = """--sql
WHERE
queue_id == ?
AND status == 'pending'
"""
self.__cursor.execute(
cursor.execute(
f"""--sql
SELECT COUNT(*)
FROM session_queue
@@ -471,8 +484,8 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
count = self.__cursor.fetchone()[0]
self.__cursor.execute(
count = cursor.fetchone()[0]
cursor.execute(
f"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -480,14 +493,15 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT * FROM session_queue
WHERE
@@ -495,18 +509,19 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(item_id,),
)
result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
try:
cursor = self._conn.cursor()
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
session_json = session.model_dump_json(warnings=False, exclude_none=True)
self.__cursor.execute(
cursor.execute(
"""--sql
UPDATE session_queue
SET session = ?
@@ -514,9 +529,9 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
return self.get_queue_item(item_id)
@@ -528,6 +543,7 @@ class SqliteSessionQueue(SessionQueueBase):
cursor: Optional[int] = None,
status: Optional[QUEUE_ITEM_STATUS] = None,
) -> CursorPaginatedResults[SessionQueueItemDTO]:
cursor_ = self._conn.cursor()
item_id = cursor
query = """--sql
SELECT item_id,
@@ -570,8 +586,8 @@ class SqliteSessionQueue(SessionQueueBase):
LIMIT ?
"""
params.append(limit + 1)
self.__cursor.execute(query, params)
results = cast(list[sqlite3.Row], self.__cursor.fetchall())
cursor_.execute(query, params)
results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
@@ -581,7 +597,8 @@ class SqliteSessionQueue(SessionQueueBase):
return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
@@ -590,7 +607,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] for row in counts_result)
@@ -609,7 +626,8 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
FROM session_queue
@@ -620,7 +638,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id, batch_id),
)
result = cast(list[sqlite3.Row], self.__cursor.fetchall())
result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
@@ -640,7 +658,8 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
self.__cursor.execute(
cursor = self._conn.cursor()
cursor.execute(
"""--sql
SELECT status, count(*)
FROM session_queue
@@ -650,7 +669,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id, destination),
)
counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
@@ -669,6 +688,7 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
try:
cursor = self._conn.cursor()
values_to_insert: list[tuple] = []
retried_item_ids: list[int] = []
@@ -711,7 +731,7 @@ class SqliteSessionQueue(SessionQueueBase):
# TODO(psyche): Handle max queue size?
self.__cursor.executemany(
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
@@ -719,9 +739,9 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
self.__conn.commit()
self._conn.commit()
except Exception:
self.__conn.rollback()
self._conn.rollback()
raise
retry_result = RetryItemsResult(
queue_id=queue_id,