mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-20 05:18:17 -05:00
696 lines
24 KiB
Python
696 lines
24 KiB
Python
import sqlite3
|
|
import threading
|
|
from typing import Optional, Union, cast
|
|
|
|
from fastapi_events.handlers.local import local_handler
|
|
from fastapi_events.typing import Event as FastAPIEvent
|
|
|
|
from invokeai.app.services.events.events_base import EventServiceBase
|
|
from invokeai.app.services.invoker import Invoker
|
|
from invokeai.app.services.session_queue.session_queue_base import SessionQueueBase
|
|
from invokeai.app.services.session_queue.session_queue_common import (
|
|
DEFAULT_QUEUE_ID,
|
|
QUEUE_ITEM_STATUS,
|
|
Batch,
|
|
BatchStatus,
|
|
CancelByBatchIDsResult,
|
|
CancelByQueueIDResult,
|
|
ClearResult,
|
|
EnqueueBatchResult,
|
|
IsEmptyResult,
|
|
IsFullResult,
|
|
PruneResult,
|
|
SessionQueueItem,
|
|
SessionQueueItemDTO,
|
|
SessionQueueItemNotFoundError,
|
|
SessionQueueStatus,
|
|
calc_session_count,
|
|
prepare_values_to_insert,
|
|
)
|
|
from invokeai.app.services.shared.pagination import CursorPaginatedResults
|
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
|
|
|
|
|
class SqliteSessionQueue(SessionQueueBase):
    """SQLite-backed session queue.

    All queries run on a single cursor created from the ``SqliteDatabase`` connection and are
    serialized with that database's re-entrant lock. Methods that write commit on success; on any
    exception the connection is rolled back and the exception is re-raised.
    """

    __invoker: Invoker
    __conn: sqlite3.Connection
    __cursor: sqlite3.Cursor
    __lock: threading.RLock

    def start(self, invoker: Invoker) -> None:
        """Bind the invoker and prepare the queue for use.

        Items left ``in_progress`` (e.g. because the app was killed mid-run) are marked canceled,
        finished items in the default queue are pruned, and the session-event handler is registered.
        """
        self.__invoker = invoker
        self._set_in_progress_to_canceled()
        prune_result = self.prune(DEFAULT_QUEUE_ID)
        local_handler.register(event_name=EventServiceBase.queue_event, _func=self._on_session_event)
        if prune_result.deleted > 0:
            self.__invoker.services.logger.info(f"Pruned {prune_result.deleted} finished queue items")

    def __init__(self, db: SqliteDatabase) -> None:
        """Share the connection and lock of ``db``; the invoker is bound later in ``start``."""
        super().__init__()
        self.__lock = db.lock
        self.__conn = db.conn
        self.__cursor = self.__conn.cursor()

    def _match_event_name(self, event: FastAPIEvent, match_in: list[str]) -> bool:
        """Return True if the event's name (``event[1]["event"]``) is one of ``match_in``."""
        return event[1]["event"] in match_in

    async def _on_session_event(self, event: FastAPIEvent) -> FastAPIEvent:
        """Update queue item status in response to session lifecycle events.

        Handles ``graph_execution_state_complete``, ``invocation_error`` and ``session_canceled``;
        all other events are ignored. Returns the event unchanged.
        """
        event_name = event[1]["event"]

        # This was a match statement, but match is not supported on python 3.9
        if event_name == "graph_execution_state_complete":
            await self._handle_complete_event(event)
        elif event_name == "invocation_error":
            await self._handle_error_event(event)
        elif event_name == "session_canceled":
            await self._handle_cancel_event(event)
        return event

    async def _handle_complete_event(self, event: FastAPIEvent) -> None:
        """Mark the event's queue item completed unless it already reached a final status."""
        try:
            item_id = event[1]["data"]["queue_item_id"]
            # When a queue item has an error, we get an error event, then a completed event.
            # Mark the queue item completed only if it isn't already in a final state, e.g.
            # failed by a previously-handled error event.
            queue_item = self.get_queue_item(item_id)
            if queue_item.status not in ["completed", "failed", "canceled"]:
                self._set_queue_item_status(item_id=queue_item.item_id, status="completed")
        except SessionQueueItemNotFoundError:
            # The item no longer exists (presumably deleted or cleared while running); nothing to update.
            return

    async def _handle_error_event(self, event: FastAPIEvent) -> None:
        """Mark the event's queue item failed and record the error message."""
        try:
            item_id = event[1]["data"]["queue_item_id"]
            error = event[1]["data"]["error"]
            queue_item = self.get_queue_item(item_id)
            # always set to failed if have an error, even if previously the item was marked completed or canceled
            self._set_queue_item_status(item_id=queue_item.item_id, status="failed", error=error)
        except SessionQueueItemNotFoundError:
            return

    async def _handle_cancel_event(self, event: FastAPIEvent) -> None:
        """Mark the event's queue item canceled unless it already reached a final status."""
        try:
            item_id = event[1]["data"]["queue_item_id"]
            queue_item = self.get_queue_item(item_id)
            if queue_item.status not in ["completed", "failed", "canceled"]:
                self._set_queue_item_status(item_id=queue_item.item_id, status="canceled")
        except SessionQueueItemNotFoundError:
            return

    def _set_in_progress_to_canceled(self) -> None:
        """
        Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
        This is necessary because the invoker may have been killed while processing a queue item.
        """
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    UPDATE session_queue
                    SET status = 'canceled'
                    WHERE status = 'in_progress';
                    """
                )
                # Commit here so the recovery is persisted on its own instead of relying on
                # whatever write happens to commit next.
                self.__conn.commit()
            except Exception:
                self.__conn.rollback()
                raise

    def _get_current_queue_size(self, queue_id: str) -> int:
        """Gets the current number of pending queue items.

        Does not take the lock itself; the caller (``enqueue_batch``) holds it.
        """
        self.__cursor.execute(
            """--sql
            SELECT count(*)
            FROM session_queue
            WHERE
              queue_id = ?
              AND status = 'pending'
            """,
            (queue_id,),
        )
        return cast(int, self.__cursor.fetchone()[0])

    def _get_highest_priority(self, queue_id: str) -> int:
        """Gets the highest priority among pending items, or 0 if there are none.

        Does not take the lock itself; the caller (``enqueue_batch``) holds it.
        """
        self.__cursor.execute(
            """--sql
            SELECT MAX(priority)
            FROM session_queue
            WHERE
              queue_id = ?
              AND status = 'pending'
            """,
            (queue_id,),
        )
        # MAX() over zero rows yields NULL -> None, which falls back to 0.
        return cast(Union[int, None], self.__cursor.fetchone()[0]) or 0

    def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
        """Insert the sessions generated from ``batch`` into ``queue_id``.

        At most ``max_queue_size - <pending items>`` sessions are inserted. When ``prepend`` is
        True the new items get a priority one above the current highest pending priority so they
        are dequeued first. Emits a batch-enqueued event after the insert commits.

        Returns:
            An ``EnqueueBatchResult`` with the requested and actually enqueued counts.
        """
        with self.__lock:
            try:
                # TODO: how does this work in a multi-user scenario?
                current_queue_size = self._get_current_queue_size(queue_id)
                max_queue_size = self.__invoker.services.configuration.max_queue_size
                # Clamp at 0: if max_queue_size is below the number of pending items, there is
                # simply no room, rather than a negative amount of room.
                max_new_queue_items = max(0, max_queue_size - current_queue_size)

                priority = 0
                if prepend:
                    priority = self._get_highest_priority(queue_id) + 1

                requested_count = calc_session_count(batch)
                values_to_insert = prepare_values_to_insert(
                    queue_id=queue_id,
                    batch=batch,
                    priority=priority,
                    max_new_queue_items=max_new_queue_items,
                )
                # Defensive cap, then count, so the reported number matches the rows inserted.
                values_to_insert = values_to_insert[:max_new_queue_items]
                enqueued_count = len(values_to_insert)

                self.__cursor.executemany(
                    """--sql
                    INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow)
                    VALUES (?, ?, ?, ?, ?, ?, ?)
                    """,
                    values_to_insert,
                )
                self.__conn.commit()
            except Exception:
                self.__conn.rollback()
                raise
        enqueue_result = EnqueueBatchResult(
            queue_id=queue_id,
            requested=requested_count,
            enqueued=enqueued_count,
            batch=batch,
            priority=priority,
        )
        self.__invoker.services.events.emit_batch_enqueued(enqueue_result)
        return enqueue_result

    def dequeue(self) -> Optional[SessionQueueItem]:
        """Take the next pending item (from any queue) and mark it ``in_progress``.

        Ordering is highest ``priority`` first, then lowest ``item_id``. Returns None when
        nothing is pending.
        """
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT *
                    FROM session_queue
                    WHERE status = 'pending'
                    ORDER BY
                      priority DESC,
                      item_id ASC
                    LIMIT 1
                    """
                )
                result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
            except Exception:
                self.__conn.rollback()
                raise
        if result is None:
            return None
        queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
        return self._set_queue_item_status(item_id=queue_item.item_id, status="in_progress")

    def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
        """Peek at the next pending item of ``queue_id`` without changing its status.

        Uses the same ordering as ``dequeue`` (priority DESC, item_id ASC). Ordering by
        ``created_at`` instead could disagree with ``dequeue`` when several rows share a timestamp.
        """
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT *
                    FROM session_queue
                    WHERE
                      queue_id = ?
                      AND status = 'pending'
                    ORDER BY
                      priority DESC,
                      item_id ASC
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
            except Exception:
                self.__conn.rollback()
                raise
        if result is None:
            return None
        return SessionQueueItem.queue_item_from_dict(dict(result))

    def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
        """Return an ``in_progress`` item of ``queue_id``, or None if none is running."""
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT *
                    FROM session_queue
                    WHERE
                      queue_id = ?
                      AND status = 'in_progress'
                    LIMIT 1
                    """,
                    (queue_id,),
                )
                result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
            except Exception:
                self.__conn.rollback()
                raise
        if result is None:
            return None
        return SessionQueueItem.queue_item_from_dict(dict(result))

    def _set_queue_item_status(
        self, item_id: int, status: QUEUE_ITEM_STATUS, error: Optional[str] = None
    ) -> SessionQueueItem:
        """Set ``status`` and ``error`` on an item, emit a status-changed event, and return the updated item.

        Raises:
            SessionQueueItemNotFoundError: if no item with ``item_id`` exists after the update.
        """
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    UPDATE session_queue
                    SET status = ?, error = ?
                    WHERE item_id = ?
                    """,
                    (status, error, item_id),
                )
                self.__conn.commit()
            except Exception:
                self.__conn.rollback()
                raise
        queue_item = self.get_queue_item(item_id)
        batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
        queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
        self.__invoker.services.events.emit_queue_item_status_changed(
            session_queue_item=queue_item,
            batch_status=batch_status,
            queue_status=queue_status,
        )
        return queue_item

    def is_empty(self, queue_id: str) -> IsEmptyResult:
        """Report whether ``queue_id`` has no items at all (of any status)."""
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT count(*)
                    FROM session_queue
                    WHERE queue_id = ?
                    """,
                    (queue_id,),
                )
                is_empty = cast(int, self.__cursor.fetchone()[0]) == 0
            except Exception:
                self.__conn.rollback()
                raise
        return IsEmptyResult(is_empty=is_empty)

    def is_full(self, queue_id: str) -> IsFullResult:
        """Report whether ``queue_id`` holds at least ``max_queue_size`` items.

        NOTE(review): this counts items of every status, whereas ``enqueue_batch`` sizes the
        remaining capacity from pending items only — confirm which semantics is intended.
        """
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT count(*)
                    FROM session_queue
                    WHERE queue_id = ?
                    """,
                    (queue_id,),
                )
                max_queue_size = self.__invoker.services.configuration.max_queue_size
                is_full = cast(int, self.__cursor.fetchone()[0]) >= max_queue_size
            except Exception:
                self.__conn.rollback()
                raise
        return IsFullResult(is_full=is_full)

    def delete_queue_item(self, item_id: int) -> SessionQueueItem:
        """Delete an item and return it as it was before deletion.

        Raises:
            SessionQueueItemNotFoundError: if no item with ``item_id`` exists.
        """
        queue_item = self.get_queue_item(item_id=item_id)
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    DELETE FROM session_queue
                    WHERE
                      item_id = ?
                    """,
                    (item_id,),
                )
                self.__conn.commit()
            except Exception:
                self.__conn.rollback()
                raise
        return queue_item

    def clear(self, queue_id: str) -> ClearResult:
        """Delete every item of ``queue_id`` regardless of status and emit a queue-cleared event."""
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT COUNT(*)
                    FROM session_queue
                    WHERE queue_id = ?
                    """,
                    (queue_id,),
                )
                count = self.__cursor.fetchone()[0]
                self.__cursor.execute(
                    """--sql
                    DELETE
                    FROM session_queue
                    WHERE queue_id = ?
                    """,
                    (queue_id,),
                )
                self.__conn.commit()
            except Exception:
                self.__conn.rollback()
                raise
        self.__invoker.services.events.emit_queue_cleared(queue_id)
        return ClearResult(deleted=count)

    def prune(self, queue_id: str) -> PruneResult:
        """Delete the completed, failed and canceled items of ``queue_id``; return how many were removed."""
        where = """--sql
            WHERE
              queue_id = ?
              AND (
                status = 'completed'
                OR status = 'failed'
                OR status = 'canceled'
              )
            """
        with self.__lock:
            try:
                self.__cursor.execute(
                    f"""--sql
                    SELECT COUNT(*)
                    FROM session_queue
                    {where};
                    """,
                    (queue_id,),
                )
                count = self.__cursor.fetchone()[0]
                self.__cursor.execute(
                    f"""--sql
                    DELETE
                    FROM session_queue
                    {where};
                    """,
                    (queue_id,),
                )
                self.__conn.commit()
            except Exception:
                self.__conn.rollback()
                raise
        return PruneResult(deleted=count)

    def cancel_queue_item(self, item_id: int, error: Optional[str] = None) -> SessionQueueItem:
        """Cancel a single item and emit a session-canceled event.

        The item becomes ``failed`` if ``error`` is given, else ``canceled``. Items already in a
        final status are left as-is (the event is still emitted).
        """
        queue_item = self.get_queue_item(item_id)
        if queue_item.status not in ["canceled", "failed", "completed"]:
            status = "failed" if error is not None else "canceled"
            queue_item = self._set_queue_item_status(item_id=item_id, status=status, error=error)  # type: ignore [arg-type] # mypy seems to not narrow the Literals here
        self.__invoker.services.events.emit_session_canceled(
            queue_item_id=queue_item.item_id,
            queue_id=queue_item.queue_id,
            queue_batch_id=queue_item.batch_id,
            graph_execution_state_id=queue_item.session_id,
        )
        return queue_item

    def _emit_current_item_canceled(self, queue_item: SessionQueueItem) -> None:
        """Emit the session-canceled and status-changed events for a bulk-canceled running item.

        ``queue_item`` is the snapshot taken before the cancel UPDATE. It is re-read here so the
        status-changed event reports the new ``canceled`` status instead of the stale one.
        """
        self.__invoker.services.events.emit_session_canceled(
            queue_item_id=queue_item.item_id,
            queue_id=queue_item.queue_id,
            queue_batch_id=queue_item.batch_id,
            graph_execution_state_id=queue_item.session_id,
        )
        updated_item = self.get_queue_item(queue_item.item_id)
        batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
        queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
        self.__invoker.services.events.emit_queue_item_status_changed(
            session_queue_item=updated_item,
            batch_status=batch_status,
            queue_status=queue_status,
        )

    def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
        """Cancel every unfinished item of ``queue_id`` whose batch is in ``batch_ids``.

        If the running item belongs to one of those batches, cancellation events are emitted for it.

        Returns:
            The number of items that were canceled.
        """
        placeholders = ", ".join("?" for _ in batch_ids)
        where = f"""--sql
            WHERE
              queue_id = ?
              AND batch_id IN ({placeholders})
              AND status != 'canceled'
              AND status != 'completed'
              AND status != 'failed'
            """
        params = (queue_id, *batch_ids)
        with self.__lock:
            # Snapshot the running item under the same lock as the UPDATE, so other locked writes
            # cannot change it in between. Doing this before entering the lock-protected block also
            # means a failure here can no longer trigger a release of an un-acquired lock.
            current_queue_item = self.get_current(queue_id)
            try:
                self.__cursor.execute(
                    f"""--sql
                    SELECT COUNT(*)
                    FROM session_queue
                    {where};
                    """,
                    params,
                )
                count = self.__cursor.fetchone()[0]
                self.__cursor.execute(
                    f"""--sql
                    UPDATE session_queue
                    SET status = 'canceled'
                    {where};
                    """,
                    params,
                )
                self.__conn.commit()
            except Exception:
                self.__conn.rollback()
                raise
            if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
                self._emit_current_item_canceled(current_queue_item)
        return CancelByBatchIDsResult(canceled=count)

    def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
        """Cancel every unfinished item of ``queue_id``, emitting cancellation events for the running item.

        Returns:
            The number of items that were canceled.
        """
        where = """--sql
            WHERE
              queue_id = ?
              AND status != 'canceled'
              AND status != 'completed'
              AND status != 'failed'
            """
        params = (queue_id,)
        with self.__lock:
            # get_current already filters by queue_id, so any result belongs to this queue.
            current_queue_item = self.get_current(queue_id)
            try:
                self.__cursor.execute(
                    f"""--sql
                    SELECT COUNT(*)
                    FROM session_queue
                    {where};
                    """,
                    params,
                )
                count = self.__cursor.fetchone()[0]
                self.__cursor.execute(
                    f"""--sql
                    UPDATE session_queue
                    SET status = 'canceled'
                    {where};
                    """,
                    params,
                )
                self.__conn.commit()
            except Exception:
                self.__conn.rollback()
                raise
            if current_queue_item is not None:
                self._emit_current_item_canceled(current_queue_item)
        return CancelByQueueIDResult(canceled=count)

    def get_queue_item(self, item_id: int) -> SessionQueueItem:
        """Fetch a single item by id.

        Raises:
            SessionQueueItemNotFoundError: if no item with ``item_id`` exists.
        """
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT * FROM session_queue
                    WHERE
                      item_id = ?
                    """,
                    (item_id,),
                )
                result = cast(Union[sqlite3.Row, None], self.__cursor.fetchone())
            except Exception:
                self.__conn.rollback()
                raise
        if result is None:
            raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
        return SessionQueueItem.queue_item_from_dict(dict(result))

    def list_queue_items(
        self,
        queue_id: str,
        limit: int,
        priority: int,
        cursor: Optional[int] = None,
        status: Optional[QUEUE_ITEM_STATUS] = None,
    ) -> CursorPaginatedResults[SessionQueueItemDTO]:
        """Return one page of items of ``queue_id``, ordered by priority DESC then item_id ASC.

        Args:
            queue_id: Queue to list.
            limit: Maximum number of items in the page.
            priority: Priority of the last item of the previous page (used with ``cursor``).
            cursor: ``item_id`` of the last item of the previous page, or None for the first page.
            status: Optional status filter.

        Returns:
            The page, with ``has_more`` set when at least one further item exists.
        """
        item_id = cursor
        query = """--sql
            SELECT item_id,
              status,
              priority,
              field_values,
              error,
              created_at,
              updated_at,
              completed_at,
              started_at,
              session_id,
              batch_id,
              queue_id
            FROM session_queue
            WHERE queue_id = ?
            """
        params: list[Union[str, int]] = [queue_id]

        if status is not None:
            query += """--sql
                AND status = ?
                """
            params.append(status)

        if item_id is not None:
            # Keyset pagination: items strictly after (priority, item_id) in the sort order.
            # The outer parentheses matter: AND binds tighter than OR, so without them the second
            # branch would bypass the queue_id and status filters above.
            query += """--sql
                AND ((priority < ?) OR (priority = ? AND item_id > ?))
                """
            params.extend([priority, priority, item_id])

        query += """--sql
            ORDER BY
              priority DESC,
              item_id ASC
            LIMIT ?
            """
        # Fetch one extra row purely to detect whether another page exists.
        params.append(limit + 1)
        with self.__lock:
            try:
                self.__cursor.execute(query, params)
                results = cast(list[sqlite3.Row], self.__cursor.fetchall())
            except Exception:
                self.__conn.rollback()
                raise
        items = [SessionQueueItemDTO.queue_item_dto_from_dict(dict(result)) for result in results]
        has_more = len(items) > limit
        if has_more:
            # remove the extra item
            items.pop()
        return CursorPaginatedResults(items=items, limit=limit, has_more=has_more)

    def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
        """Summarize ``queue_id``: per-status counts, total, and the currently running item's ids."""
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT status, count(*)
                    FROM session_queue
                    WHERE queue_id = ?
                    GROUP BY status
                    """,
                    (queue_id,),
                )
                counts_result = cast(list[sqlite3.Row], self.__cursor.fetchall())
            except Exception:
                self.__conn.rollback()
                raise

        current_item = self.get_current(queue_id=queue_id)
        total = sum(row[1] for row in counts_result)
        counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
        return SessionQueueStatus(
            queue_id=queue_id,
            item_id=current_item.item_id if current_item else None,
            session_id=current_item.session_id if current_item else None,
            batch_id=current_item.batch_id if current_item else None,
            pending=counts.get("pending", 0),
            in_progress=counts.get("in_progress", 0),
            completed=counts.get("completed", 0),
            failed=counts.get("failed", 0),
            canceled=counts.get("canceled", 0),
            total=total,
        )

    def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
        """Summarize the items of ``batch_id`` in ``queue_id``: per-status counts and total."""
        with self.__lock:
            try:
                self.__cursor.execute(
                    """--sql
                    SELECT status, count(*)
                    FROM session_queue
                    WHERE
                      queue_id = ?
                      AND batch_id = ?
                    GROUP BY status
                    """,
                    (queue_id, batch_id),
                )
                result = cast(list[sqlite3.Row], self.__cursor.fetchall())
            except Exception:
                self.__conn.rollback()
                raise
        total = sum(row[1] for row in result)
        counts: dict[str, int] = {row[0]: row[1] for row in result}

        return BatchStatus(
            batch_id=batch_id,
            queue_id=queue_id,
            pending=counts.get("pending", 0),
            in_progress=counts.get("in_progress", 0),
            completed=counts.get("completed", 0),
            failed=counts.get("failed", 0),
            canceled=counts.get("canceled", 0),
            total=total,
        )
|