From ef4d5d73775759ed5c5bb955dc8cd746c915f2c0 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 10 Jul 2025 20:04:07 +1000
Subject: [PATCH 01/17] feat(ui): virtualized list for staging area
Make the staging area a virtualized list so it doesn't choke when there
are a large number (i.e. more than a few hundred) of queue items.
---
.../SimpleSession/QueueItemPreviewMini.tsx | 9 +-
.../SimpleSession/StagingAreaItemsList.tsx | 208 ++++++++++++++++--
.../components/SimpleSession/shared.ts | 2 +-
.../StagingArea/StagingAreaToolbar.tsx | 13 +-
.../gallery/components/NewGallery.tsx | 9 +-
.../src/features/ui/layouts/StagingArea.tsx | 2 +-
6 files changed, 203 insertions(+), 40 deletions(-)
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/QueueItemPreviewMini.tsx b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/QueueItemPreviewMini.tsx
index 0c0d70f060..6ea9a52eb0 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/QueueItemPreviewMini.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/QueueItemPreviewMini.tsx
@@ -27,6 +27,7 @@ const sx = {
alignItems: 'center',
justifyContent: 'center',
flexShrink: 0,
+ h: 'full',
aspectRatio: '1/1',
borderWidth: 2,
borderRadius: 'base',
@@ -39,11 +40,11 @@ const sx = {
type Props = {
item: S['SessionQueueItem'];
- number: number;
+ index: number;
isSelected: boolean;
};
-export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) => {
+export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) => {
const dispatch = useAppDispatch();
const ctx = useCanvasSessionContext();
const { imageLoaded } = useProgressData(ctx.$progressData, item.item_id);
@@ -69,7 +70,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, number }: Props) =
return (
{imageDTO && }
{!imageLoaded && }
-
+
);
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/StagingAreaItemsList.tsx b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/StagingAreaItemsList.tsx
index b4606cd58a..c644608a2e 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/StagingAreaItemsList.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/StagingAreaItemsList.tsx
@@ -1,17 +1,148 @@
-import { Flex } from '@invoke-ai/ui-library';
+import { Box, Flex, forwardRef } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
-import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
+import { logger } from 'app/logging/logger';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
import { QueueItemPreviewMini } from 'features/controlLayers/components/SimpleSession/QueueItemPreviewMini';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
-import { memo, useEffect } from 'react';
+import { useOverlayScrollbars } from 'overlayscrollbars-react';
+import type { CSSProperties, RefObject } from 'react';
+import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
+import type { Components, ItemContent, ListRange, VirtuosoHandle, VirtuosoProps } from 'react-virtuoso';
+import { Virtuoso } from 'react-virtuoso';
+import type { S } from 'services/api/types';
+
+import { getQueueItemElementId } from './shared';
+
+const log = logger('system');
+
+const virtuosoStyles = {
+ width: '100%',
+ height: '72px',
+} satisfies CSSProperties;
+
+type VirtuosoContext = { selectedItemId: number | null };
+
+/**
+ * Scroll the item at the given index into view if it is not currently visible.
+ */
+const scrollIntoView = (
+ targetIndex: number,
+ rootEl: HTMLDivElement,
+ virtuosoHandle: VirtuosoHandle,
+ range: ListRange
+) => {
+ if (range.endIndex === 0) {
+ // No range is rendered; no need to scroll to anything.
+ return;
+ }
+
+ const targetItem = rootEl.querySelector(`#${getQueueItemElementId(targetIndex)}`);
+
+ if (!targetItem) {
+ if (targetIndex > range.endIndex) {
+ virtuosoHandle.scrollToIndex({
+ index: targetIndex,
+ behavior: 'auto',
+ align: 'end',
+ });
+ } else if (targetIndex < range.startIndex) {
+ virtuosoHandle.scrollToIndex({
+ index: targetIndex,
+ behavior: 'auto',
+ align: 'start',
+ });
+ } else {
+ log.debug(
+ `Unable to find queue item at index ${targetIndex} but it is in the rendered range ${range.startIndex}-${range.endIndex}`
+ );
+ }
+ return;
+ }
+
+ // We found the image in the DOM, but it might be in the overscan range - rendered but not in the visible viewport.
+ // Check if it is in the viewport and scroll if necessary.
+
+ const itemRect = targetItem.getBoundingClientRect();
+ const rootRect = rootEl.getBoundingClientRect();
+
+ if (itemRect.left < rootRect.left) {
+ virtuosoHandle.scrollToIndex({
+ index: targetIndex,
+ behavior: 'auto',
+ align: 'start',
+ });
+ } else if (itemRect.right > rootRect.right) {
+ virtuosoHandle.scrollToIndex({
+ index: targetIndex,
+ behavior: 'auto',
+ align: 'end',
+ });
+ } else {
+ // Image is already in view
+ }
+
+ return;
+};
+
+const useScrollableStagingArea = (rootRef: RefObject) => {
+ const [scroller, scrollerRef] = useState(null);
+ const [initialize, osInstance] = useOverlayScrollbars({
+ defer: true,
+ events: {
+ initialized(osInstance) {
+ // force overflow styles
+ const { viewport } = osInstance.elements();
+ viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
+ viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
+ },
+ },
+ options: {
+ scrollbars: {
+ visibility: 'auto',
+ autoHide: 'scroll',
+ autoHideDelay: 1300,
+ theme: 'os-theme-dark',
+ },
+ overflow: {
+ y: 'hidden',
+ x: 'scroll',
+ },
+ },
+ });
+
+ useEffect(() => {
+ const { current: root } = rootRef;
+
+ if (scroller && root) {
+ initialize({
+ target: root,
+ elements: {
+ viewport: scroller,
+ },
+ });
+ }
+
+ return () => {
+ osInstance()?.destroy();
+ };
+ }, [scroller, initialize, osInstance, rootRef]);
+
+ return scrollerRef;
+};
export const StagingAreaItemsList = memo(() => {
const canvasManager = useCanvasManagerSafe();
const ctx = useCanvasSessionContext();
+ const virtuosoRef = useRef(null);
+ const rangeRef = useRef({ startIndex: 0, endIndex: 0 });
+ const rootRef = useRef(null);
+
const items = useStore(ctx.$items);
const selectedItemId = useStore(ctx.$selectedItemId);
+ const context = useMemo(() => ({ selectedItemId }), [selectedItemId]);
+ const scrollerRef = useScrollableStagingArea(rootRef);
+
useEffect(() => {
if (!canvasManager) {
return;
@@ -20,21 +151,64 @@ export const StagingAreaItemsList = memo(() => {
return canvasManager.stagingArea.connectToSession(ctx.$selectedItemId, ctx.$progressData, ctx.$isPending);
}, [canvasManager, ctx.$progressData, ctx.$selectedItemId, ctx.$isPending]);
+ useEffect(() => {
+ return ctx.$selectedItemIndex.listen((index) => {
+ if (!virtuosoRef.current) {
+ return;
+ }
+
+ if (!rootRef.current) {
+ return;
+ }
+
+ if (index === null) {
+ return;
+ }
+
+ scrollIntoView(index, rootRef.current, virtuosoRef.current, rangeRef.current);
+ });
+ }, [ctx.$selectedItemIndex]);
+
+ const onRangeChanged = useCallback((range: ListRange) => {
+ rangeRef.current = range;
+ }, []);
+
return (
-
-
-
- {items.map((item, i) => (
-
- ))}
-
-
-
+
+
+ ref={virtuosoRef}
+ context={context}
+ data={items}
+ horizontalDirection
+ style={virtuosoStyles}
+ itemContent={itemContent}
+ components={components}
+ rangeChanged={onRangeChanged}
+ // Virtuoso expects the ref to be of HTMLElement | null | Window, but overlayscrollbars doesn't allow Window
+ scrollerRef={scrollerRef as VirtuosoProps['scrollerRef']}
+ />
+
);
});
StagingAreaItemsList.displayName = 'StagingAreaItemsList';
+
+const itemContent: ItemContent = (index, item, { selectedItemId }) => (
+
+);
+
+const listSx = {
+ '& > * + *': {
+ pl: 2,
+ },
+};
+
+const components: Components = {
+ List: forwardRef(({ context: _, ...rest }, ref) => {
+ return ;
+ }),
+};
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/shared.ts b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/shared.ts
index be7e97d62a..fe98408df5 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/shared.ts
+++ b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/shared.ts
@@ -13,7 +13,7 @@ export const getProgressMessage = (data?: S['InvocationProgressEvent'] | null) =
export const DROP_SHADOW = 'drop-shadow(0px 0px 4px rgb(0, 0, 0)) drop-shadow(0px 0px 4px rgba(0, 0, 0, 0.3))';
-export const getQueueItemElementId = (itemId: number) => `queue-item-status-card-${itemId}`;
+export const getQueueItemElementId = (index: number) => `queue-item-preview-${index}`;
export const getOutputImageName = (item: S['SessionQueueItem']) => {
const nodeId = Object.entries(item.session.source_prepared_mapping).find(([nodeId]) =>
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/StagingAreaToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/StagingAreaToolbar.tsx
index 9a31934480..64df68cf26 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/StagingAreaToolbar.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/StagingArea/StagingAreaToolbar.tsx
@@ -1,7 +1,6 @@
import { ButtonGroup, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useCanvasSessionContext } from 'features/controlLayers/components/SimpleSession/context';
-import { getQueueItemElementId } from 'features/controlLayers/components/SimpleSession/shared';
import { StagingAreaToolbarAcceptButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarAcceptButton';
import { StagingAreaToolbarDiscardAllButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardAllButton';
import { StagingAreaToolbarDiscardSelectedButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarDiscardSelectedButton';
@@ -12,7 +11,7 @@ import { StagingAreaToolbarPrevButton } from 'features/controlLayers/components/
import { StagingAreaToolbarSaveSelectedToGalleryButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarSaveSelectedToGalleryButton';
import { StagingAreaToolbarToggleShowResultsButton } from 'features/controlLayers/components/StagingArea/StagingAreaToolbarToggleShowResultsButton';
import { useCanvasManager } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
-import { memo, useEffect } from 'react';
+import { memo } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { StagingAreaAutoSwitchButtons } from './StagingAreaAutoSwitchButtons';
@@ -23,16 +22,6 @@ export const StagingAreaToolbar = memo(() => {
const ctx = useCanvasSessionContext();
- useEffect(() => {
- return ctx.$selectedItemId.listen((id) => {
- if (id !== null) {
- document
- .getElementById(getQueueItemElementId(id))
- ?.scrollIntoView({ block: 'nearest', inline: 'nearest', behavior: 'auto' });
- }
- });
- }, [ctx.$selectedItemId]);
-
useHotkeys('meta+left', ctx.selectFirst, { preventDefault: true });
useHotkeys('meta+right', ctx.selectLast, { preventDefault: true });
diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx
index 326594e65c..affe18118e 100644
--- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx
@@ -482,11 +482,6 @@ export const NewGallery = memo(() => {
const context = useMemo(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]);
- // Item content function
- const itemContent: GridItemContent = useCallback((index, imageName) => {
- return ;
- }, []);
-
if (isLoading) {
return (
@@ -553,6 +548,10 @@ const ListComponent: GridComponents['List'] = forwardRef(({ context
});
ListComponent.displayName = 'ListComponent';
+const itemContent: GridItemContent = (index, imageName) => {
+ return ;
+};
+
const ItemComponent: GridComponents['Item'] = forwardRef(({ context: _, ...rest }, ref) => (
));
diff --git a/invokeai/frontend/web/src/features/ui/layouts/StagingArea.tsx b/invokeai/frontend/web/src/features/ui/layouts/StagingArea.tsx
index e24e9ac5fe..822e5c3249 100644
--- a/invokeai/frontend/web/src/features/ui/layouts/StagingArea.tsx
+++ b/invokeai/frontend/web/src/features/ui/layouts/StagingArea.tsx
@@ -13,7 +13,7 @@ export const StagingArea = memo(() => {
}
return (
-
+
From a19aa3b0324a0c099e2695e52deaedf5428e0465 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Thu, 10 Jul 2025 18:03:21 +1000
Subject: [PATCH 02/17] feat(app): db abstraction to prevent threading
conflicts
- Add a context manager to the SqliteDatabase class which abstracts away
creating a transaction, committing it on success and rolling back on
error.
- Use it everywhere. The context manager should be exited before
returning results. No business logic changes should be present.
---
.../board_image_records_sqlite.py | 190 +++--
.../board_records/board_records_sqlite.py | 268 +++----
.../image_records/image_records_sqlite.py | 669 +++++++++---------
.../model_records/model_records_sql.py | 319 ++++-----
.../model_relationship_records_sqlite.py | 74 +-
.../session_queue/session_queue_sqlite.py | 592 ++++++++--------
.../services/shared/sqlite/sqlite_database.py | 65 +-
.../sqlite_migrator/sqlite_migrator_impl.py | 12 +-
.../style_preset_records_sqlite.py | 96 +--
.../workflow_records_sqlite.py | 358 +++++-----
tests/test_sqlite_migrator.py | 22 +-
11 files changed, 1296 insertions(+), 1369 deletions(-)
diff --git a/invokeai/app/services/board_image_records/board_image_records_sqlite.py b/invokeai/app/services/board_image_records/board_image_records_sqlite.py
index cc2c9e1379..a6c178097e 100644
--- a/invokeai/app/services/board_image_records/board_image_records_sqlite.py
+++ b/invokeai/app/services/board_image_records/board_image_records_sqlite.py
@@ -14,16 +14,15 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
- self._conn = db.conn
+ self._db = db
def add_image_to_board(
self,
board_id: str,
image_name: str,
) -> None:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
+ with self._db.conn() as conn:
+ conn.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
@@ -31,28 +30,19 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
""",
(board_id, image_name, board_id),
)
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise e
def remove_image_from_board(
self,
image_name: str,
) -> None:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
+ with self._db.conn() as conn:
+ conn.execute(
"""--sql
DELETE FROM board_images
WHERE image_name = ?;
""",
(image_name,),
)
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise e
def get_images_for_board(
self,
@@ -60,27 +50,27 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
- # TODO: this isn't paginated yet?
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT images.*
- FROM board_images
- INNER JOIN images ON board_images.image_name = images.image_name
- WHERE board_images.board_id = ?
- ORDER BY board_images.updated_at DESC;
- """,
- (board_id,),
- )
- result = cast(list[sqlite3.Row], cursor.fetchall())
- images = [deserialize_image_record(dict(r)) for r in result]
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT images.*
+ FROM board_images
+ INNER JOIN images ON board_images.image_name = images.image_name
+ WHERE board_images.board_id = ?
+ ORDER BY board_images.updated_at DESC;
+ """,
+ (board_id,),
+ )
+ result = cast(list[sqlite3.Row], cursor.fetchall())
+ images = [deserialize_image_record(dict(r)) for r in result]
- cursor.execute(
- """--sql
- SELECT COUNT(*) FROM images WHERE 1=1;
- """
- )
- count = cast(int, cursor.fetchone()[0])
+ cursor.execute(
+ """--sql
+ SELECT COUNT(*) FROM images WHERE 1=1;
+ """
+ )
+ count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
@@ -90,56 +80,56 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
- params: list[str | bool] = []
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ params: list[str | bool] = []
- # Base query is a join between images and board_images
- stmt = """
- SELECT images.image_name
- FROM images
- LEFT JOIN board_images ON board_images.image_name = images.image_name
- WHERE 1=1
- """
+ # Base query is a join between images and board_images
+ stmt = """
+ SELECT images.image_name
+ FROM images
+ LEFT JOIN board_images ON board_images.image_name = images.image_name
+ WHERE 1=1
+ """
- # Handle board_id filter
- if board_id == "none":
- stmt += """--sql
- AND board_images.board_id IS NULL
- """
- else:
- stmt += """--sql
- AND board_images.board_id = ?
- """
- params.append(board_id)
+ # Handle board_id filter
+ if board_id == "none":
+ stmt += """--sql
+ AND board_images.board_id IS NULL
+ """
+ else:
+ stmt += """--sql
+ AND board_images.board_id = ?
+ """
+ params.append(board_id)
- # Add the category filter
- if categories is not None:
- # Convert the enum values to unique list of strings
- category_strings = [c.value for c in set(categories)]
- # Create the correct length of placeholders
- placeholders = ",".join("?" * len(category_strings))
- stmt += f"""--sql
- AND images.image_category IN ( {placeholders} )
- """
+ # Add the category filter
+ if categories is not None:
+ # Convert the enum values to unique list of strings
+ category_strings = [c.value for c in set(categories)]
+ # Create the correct length of placeholders
+ placeholders = ",".join("?" * len(category_strings))
+ stmt += f"""--sql
+ AND images.image_category IN ( {placeholders} )
+ """
- # Unpack the included categories into the query params
- for c in category_strings:
- params.append(c)
+ # Unpack the included categories into the query params
+ for c in category_strings:
+ params.append(c)
- # Add the is_intermediate filter
- if is_intermediate is not None:
- stmt += """--sql
- AND images.is_intermediate = ?
- """
- params.append(is_intermediate)
+ # Add the is_intermediate filter
+ if is_intermediate is not None:
+ stmt += """--sql
+ AND images.is_intermediate = ?
+ """
+ params.append(is_intermediate)
- # Put a ring on it
- stmt += ";"
+ # Put a ring on it
+ stmt += ";"
- # Execute the query
- cursor = self._conn.cursor()
- cursor.execute(stmt, params)
+ cursor.execute(stmt, params)
- result = cast(list[sqlite3.Row], cursor.fetchall())
+ result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [r[0] for r in result]
return image_names
@@ -147,31 +137,33 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT board_id
- FROM board_images
- WHERE image_name = ?;
- """,
- (image_name,),
- )
- result = cursor.fetchone()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT board_id
+ FROM board_images
+ WHERE image_name = ?;
+ """,
+ (image_name,),
+ )
+ result = cursor.fetchone()
if result is None:
return None
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT COUNT(*)
- FROM board_images
- INNER JOIN images ON board_images.image_name = images.image_name
- WHERE images.is_intermediate = FALSE
- AND board_images.board_id = ?;
- """,
- (board_id,),
- )
- count = cast(int, cursor.fetchone()[0])
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT COUNT(*)
+ FROM board_images
+ INNER JOIN images ON board_images.image_name = images.image_name
+ WHERE images.is_intermediate = FALSE
+ AND board_images.board_id = ?;
+ """,
+ (board_id,),
+ )
+ count = cast(int, cursor.fetchone()[0])
return count
diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py
index 86dd0e2a86..696ffab4b2 100644
--- a/invokeai/app/services/board_records/board_records_sqlite.py
+++ b/invokeai/app/services/board_records/board_records_sqlite.py
@@ -20,61 +20,60 @@ from invokeai.app.util.misc import uuid_string
class SqliteBoardRecordStorage(BoardRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
- self._conn = db.conn
+ self._db = db
def delete(self, board_id: str) -> None:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- DELETE FROM boards
- WHERE board_id = ?;
- """,
- (board_id,),
- )
- self._conn.commit()
- except Exception as e:
- self._conn.rollback()
- raise BoardRecordDeleteException from e
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ DELETE FROM boards
+ WHERE board_id = ?;
+ """,
+ (board_id,),
+ )
+ except Exception as e:
+ raise BoardRecordDeleteException from e
def save(
self,
board_name: str,
) -> BoardRecord:
- try:
- board_id = uuid_string()
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- INSERT OR IGNORE INTO boards (board_id, board_name)
- VALUES (?, ?);
- """,
- (board_id, board_name),
- )
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise BoardRecordSaveException from e
+ with self._db.conn() as conn:
+ try:
+ board_id = uuid_string()
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ INSERT OR IGNORE INTO boards (board_id, board_name)
+ VALUES (?, ?);
+ """,
+ (board_id, board_name),
+ )
+ except sqlite3.Error as e:
+ raise BoardRecordSaveException from e
return self.get(board_id)
def get(
self,
board_id: str,
) -> BoardRecord:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT *
- FROM boards
- WHERE board_id = ?;
- """,
- (board_id,),
- )
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT *
+ FROM boards
+ WHERE board_id = ?;
+ """,
+ (board_id,),
+ )
- result = cast(Union[sqlite3.Row, None], cursor.fetchone())
- except sqlite3.Error as e:
- raise BoardRecordNotFoundException from e
+ result = cast(Union[sqlite3.Row, None], cursor.fetchone())
+ except sqlite3.Error as e:
+ raise BoardRecordNotFoundException from e
if result is None:
raise BoardRecordNotFoundException
return BoardRecord(**dict(result))
@@ -84,45 +83,44 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
- try:
- cursor = self._conn.cursor()
- # Change the name of a board
- if changes.board_name is not None:
- cursor.execute(
- """--sql
- UPDATE boards
- SET board_name = ?
- WHERE board_id = ?;
- """,
- (changes.board_name, board_id),
- )
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ # Change the name of a board
+ if changes.board_name is not None:
+ cursor.execute(
+ """--sql
+ UPDATE boards
+ SET board_name = ?
+ WHERE board_id = ?;
+ """,
+ (changes.board_name, board_id),
+ )
- # Change the cover image of a board
- if changes.cover_image_name is not None:
- cursor.execute(
- """--sql
- UPDATE boards
- SET cover_image_name = ?
- WHERE board_id = ?;
- """,
- (changes.cover_image_name, board_id),
- )
+ # Change the cover image of a board
+ if changes.cover_image_name is not None:
+ cursor.execute(
+ """--sql
+ UPDATE boards
+ SET cover_image_name = ?
+ WHERE board_id = ?;
+ """,
+ (changes.cover_image_name, board_id),
+ )
- # Change the archived status of a board
- if changes.archived is not None:
- cursor.execute(
- """--sql
- UPDATE boards
- SET archived = ?
- WHERE board_id = ?;
- """,
- (changes.archived, board_id),
- )
+ # Change the archived status of a board
+ if changes.archived is not None:
+ cursor.execute(
+ """--sql
+ UPDATE boards
+ SET archived = ?
+ WHERE board_id = ?;
+ """,
+ (changes.archived, board_id),
+ )
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise BoardRecordSaveException from e
+ except sqlite3.Error as e:
+ raise BoardRecordSaveException from e
return self.get(board_id)
def get_many(
@@ -133,78 +131,80 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
- # Build base query
- base_query = """
- SELECT *
- FROM boards
- {archived_filter}
- ORDER BY {order_by} {direction}
- LIMIT ? OFFSET ?;
- """
-
- # Determine archived filter condition
- archived_filter = "" if include_archived else "WHERE archived = 0"
-
- final_query = base_query.format(
- archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
- )
-
- # Execute query to fetch boards
- cursor.execute(final_query, (limit, offset))
-
- result = cast(list[sqlite3.Row], cursor.fetchall())
- boards = [deserialize_board_record(dict(r)) for r in result]
-
- # Determine count query
- if include_archived:
- count_query = """
- SELECT COUNT(*)
- FROM boards;
- """
- else:
- count_query = """
- SELECT COUNT(*)
+ # Build base query
+ base_query = """
+ SELECT *
FROM boards
- WHERE archived = 0;
+ {archived_filter}
+ ORDER BY {order_by} {direction}
+ LIMIT ? OFFSET ?;
"""
- # Execute count query
- cursor.execute(count_query)
+ # Determine archived filter condition
+ archived_filter = "" if include_archived else "WHERE archived = 0"
- count = cast(int, cursor.fetchone()[0])
+ final_query = base_query.format(
+ archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
+ )
+
+ # Execute query to fetch boards
+ cursor.execute(final_query, (limit, offset))
+
+ result = cast(list[sqlite3.Row], cursor.fetchall())
+ boards = [deserialize_board_record(dict(r)) for r in result]
+
+ # Determine count query
+ if include_archived:
+ count_query = """
+ SELECT COUNT(*)
+ FROM boards;
+ """
+ else:
+ count_query = """
+ SELECT COUNT(*)
+ FROM boards
+ WHERE archived = 0;
+ """
+
+ # Execute count query
+ cursor.execute(count_query)
+
+ count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
- cursor = self._conn.cursor()
- if order_by == BoardRecordOrderBy.Name:
- base_query = """
- SELECT *
- FROM boards
- {archived_filter}
- ORDER BY LOWER(board_name) {direction}
- """
- else:
- base_query = """
- SELECT *
- FROM boards
- {archived_filter}
- ORDER BY {order_by} {direction}
- """
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ if order_by == BoardRecordOrderBy.Name:
+ base_query = """
+ SELECT *
+ FROM boards
+ {archived_filter}
+ ORDER BY LOWER(board_name) {direction}
+ """
+ else:
+ base_query = """
+ SELECT *
+ FROM boards
+ {archived_filter}
+ ORDER BY {order_by} {direction}
+ """
- archived_filter = "" if include_archived else "WHERE archived = 0"
+ archived_filter = "" if include_archived else "WHERE archived = 0"
- final_query = base_query.format(
- archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
- )
+ final_query = base_query.format(
+ archived_filter=archived_filter, order_by=order_by.value, direction=direction.value
+ )
- cursor.execute(final_query)
+ cursor.execute(final_query)
- result = cast(list[sqlite3.Row], cursor.fetchall())
+ result = cast(list[sqlite3.Row], cursor.fetchall())
boards = [deserialize_board_record(dict(r)) for r in result]
return boards
diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py
index 72c18feb98..78f725ba17 100644
--- a/invokeai/app/services/image_records/image_records_sqlite.py
+++ b/invokeai/app/services/image_records/image_records_sqlite.py
@@ -24,22 +24,23 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteImageRecordStorage(ImageRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
- self._conn = db.conn
+ self._db = db
def get(self, image_name: str) -> ImageRecord:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
- f"""--sql
- SELECT {IMAGE_DTO_COLS} FROM images
- WHERE image_name = ?;
- """,
- (image_name,),
- )
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ f"""--sql
+ SELECT {IMAGE_DTO_COLS} FROM images
+ WHERE image_name = ?;
+ """,
+ (image_name,),
+ )
- result = cast(Optional[sqlite3.Row], cursor.fetchone())
- except sqlite3.Error as e:
- raise ImageRecordNotFoundException from e
+ result = cast(Optional[sqlite3.Row], cursor.fetchone())
+ except sqlite3.Error as e:
+ raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -47,17 +48,21 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT metadata FROM images
- WHERE image_name = ?;
- """,
- (image_name,),
- )
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT metadata FROM images
+ WHERE image_name = ?;
+ """,
+ (image_name,),
+ )
- result = cast(Optional[sqlite3.Row], cursor.fetchone())
+ result = cast(Optional[sqlite3.Row], cursor.fetchone())
+
+ except sqlite3.Error as e:
+ raise ImageRecordNotFoundException from e
if not result:
raise ImageRecordNotFoundException
@@ -65,64 +70,61 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
as_dict = dict(result)
metadata_raw = cast(Optional[str], as_dict.get("metadata", None))
return MetadataFieldValidator.validate_json(metadata_raw) if metadata_raw is not None else None
- except sqlite3.Error as e:
- raise ImageRecordNotFoundException from e
def update(
self,
image_name: str,
changes: ImageRecordChanges,
) -> None:
- try:
- cursor = self._conn.cursor()
- # Change the category of the image
- if changes.image_category is not None:
- cursor.execute(
- """--sql
- UPDATE images
- SET image_category = ?
- WHERE image_name = ?;
- """,
- (changes.image_category, image_name),
- )
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ # Change the category of the image
+ if changes.image_category is not None:
+ cursor.execute(
+ """--sql
+ UPDATE images
+ SET image_category = ?
+ WHERE image_name = ?;
+ """,
+ (changes.image_category, image_name),
+ )
- # Change the session associated with the image
- if changes.session_id is not None:
- cursor.execute(
- """--sql
- UPDATE images
- SET session_id = ?
- WHERE image_name = ?;
- """,
- (changes.session_id, image_name),
- )
+ # Change the session associated with the image
+ if changes.session_id is not None:
+ cursor.execute(
+ """--sql
+ UPDATE images
+ SET session_id = ?
+ WHERE image_name = ?;
+ """,
+ (changes.session_id, image_name),
+ )
- # Change the image's `is_intermediate`` flag
- if changes.is_intermediate is not None:
- cursor.execute(
- """--sql
- UPDATE images
- SET is_intermediate = ?
- WHERE image_name = ?;
- """,
- (changes.is_intermediate, image_name),
- )
+ # Change the image's `is_intermediate`` flag
+ if changes.is_intermediate is not None:
+ cursor.execute(
+ """--sql
+ UPDATE images
+ SET is_intermediate = ?
+ WHERE image_name = ?;
+ """,
+ (changes.is_intermediate, image_name),
+ )
- # Change the image's `starred`` state
- if changes.starred is not None:
- cursor.execute(
- """--sql
- UPDATE images
- SET starred = ?
- WHERE image_name = ?;
- """,
- (changes.starred, image_name),
- )
+ # Change the image's `starred`` state
+ if changes.starred is not None:
+ cursor.execute(
+ """--sql
+ UPDATE images
+ SET starred = ?
+ WHERE image_name = ?;
+ """,
+ (changes.starred, image_name),
+ )
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise ImageRecordSaveException from e
+ except sqlite3.Error as e:
+ raise ImageRecordSaveException from e
def get_many(
self,
@@ -136,94 +138,95 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
- # Manually build two queries - one for the count, one for the records
- count_query = """--sql
- SELECT COUNT(*)
- FROM images
- LEFT JOIN board_images ON board_images.image_name = images.image_name
- WHERE 1=1
- """
-
- images_query = f"""--sql
- SELECT {IMAGE_DTO_COLS}
- FROM images
- LEFT JOIN board_images ON board_images.image_name = images.image_name
- WHERE 1=1
- """
-
- query_conditions = ""
- query_params: list[Union[int, str, bool]] = []
-
- if image_origin is not None:
- query_conditions += """--sql
- AND images.image_origin = ?
- """
- query_params.append(image_origin.value)
-
- if categories is not None:
- # Convert the enum values to unique list of strings
- category_strings = [c.value for c in set(categories)]
- # Create the correct length of placeholders
- placeholders = ",".join("?" * len(category_strings))
-
- query_conditions += f"""--sql
- AND images.image_category IN ( {placeholders} )
+ # Manually build two queries - one for the count, one for the records
+ count_query = """--sql
+ SELECT COUNT(*)
+ FROM images
+ LEFT JOIN board_images ON board_images.image_name = images.image_name
+ WHERE 1=1
"""
- # Unpack the included categories into the query params
- for c in category_strings:
- query_params.append(c)
-
- if is_intermediate is not None:
- query_conditions += """--sql
- AND images.is_intermediate = ?
+ images_query = f"""--sql
+ SELECT {IMAGE_DTO_COLS}
+ FROM images
+ LEFT JOIN board_images ON board_images.image_name = images.image_name
+ WHERE 1=1
"""
- query_params.append(is_intermediate)
+ query_conditions = ""
+ query_params: list[Union[int, str, bool]] = []
- # board_id of "none" is reserved for images without a board
- if board_id == "none":
- query_conditions += """--sql
- AND board_images.board_id IS NULL
- """
- elif board_id is not None:
- query_conditions += """--sql
- AND board_images.board_id = ?
- """
- query_params.append(board_id)
+ if image_origin is not None:
+ query_conditions += """--sql
+ AND images.image_origin = ?
+ """
+ query_params.append(image_origin.value)
- # Search term condition
- if search_term:
- query_conditions += """--sql
- AND (
- images.metadata LIKE ?
- OR images.created_at LIKE ?
- )
- """
- query_params.append(f"%{search_term.lower()}%")
- query_params.append(f"%{search_term.lower()}%")
+ if categories is not None:
+ # Convert the enum values to unique list of strings
+ category_strings = [c.value for c in set(categories)]
+ # Create the correct length of placeholders
+ placeholders = ",".join("?" * len(category_strings))
- if starred_first:
- query_pagination = f"""--sql
- ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
- """
- else:
- query_pagination = f"""--sql
- ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
- """
+ query_conditions += f"""--sql
+ AND images.image_category IN ( {placeholders} )
+ """
- # Final images query with pagination
- images_query += query_conditions + query_pagination + ";"
- # Add all the parameters
- images_params = query_params.copy()
- # Add the pagination parameters
- images_params.extend([limit, offset])
+ # Unpack the included categories into the query params
+ for c in category_strings:
+ query_params.append(c)
- # Build the list of images, deserializing each row
- cursor.execute(images_query, images_params)
- result = cast(list[sqlite3.Row], cursor.fetchall())
+ if is_intermediate is not None:
+ query_conditions += """--sql
+ AND images.is_intermediate = ?
+ """
+
+ query_params.append(is_intermediate)
+
+ # board_id of "none" is reserved for images without a board
+ if board_id == "none":
+ query_conditions += """--sql
+ AND board_images.board_id IS NULL
+ """
+ elif board_id is not None:
+ query_conditions += """--sql
+ AND board_images.board_id = ?
+ """
+ query_params.append(board_id)
+
+ # Search term condition
+ if search_term:
+ query_conditions += """--sql
+ AND (
+ images.metadata LIKE ?
+ OR images.created_at LIKE ?
+ )
+ """
+ query_params.append(f"%{search_term.lower()}%")
+ query_params.append(f"%{search_term.lower()}%")
+
+ if starred_first:
+ query_pagination = f"""--sql
+ ORDER BY images.starred DESC, images.created_at {order_dir.value} LIMIT ? OFFSET ?
+ """
+ else:
+ query_pagination = f"""--sql
+ ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ?
+ """
+
+ # Final images query with pagination
+ images_query += query_conditions + query_pagination + ";"
+ # Add all the parameters
+ images_params = query_params.copy()
+ # Add the pagination parameters
+ images_params.extend([limit, offset])
+
+ # Build the list of images, deserializing each row
+ cursor.execute(images_query, images_params)
+ result = cast(list[sqlite3.Row], cursor.fetchall())
images = [deserialize_image_record(dict(r)) for r in result]
# Set up and execute the count query, without pagination
@@ -235,71 +238,68 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- DELETE FROM images
- WHERE image_name = ?;
- """,
- (image_name,),
- )
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise ImageRecordDeleteException from e
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ DELETE FROM images
+ WHERE image_name = ?;
+ """,
+ (image_name,),
+ )
+ except sqlite3.Error as e:
+ raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
- placeholders = ",".join("?" for _ in image_names)
+ placeholders = ",".join("?" for _ in image_names)
- # Construct the SQLite query with the placeholders
- query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
+ # Construct the SQLite query with the placeholders
+ query = f"DELETE FROM images WHERE image_name IN ({placeholders})"
- # Execute the query with the list of IDs as parameters
- cursor.execute(query, image_names)
+ # Execute the query with the list of IDs as parameters
+ cursor.execute(query, image_names)
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise ImageRecordDeleteException from e
+ except sqlite3.Error as e:
+ raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT COUNT(*) FROM images
- WHERE is_intermediate = TRUE;
- """
- )
- count = cast(int, cursor.fetchone()[0])
- self._conn.commit()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT COUNT(*) FROM images
+ WHERE is_intermediate = TRUE;
+ """
+ )
+ count = cast(int, cursor.fetchone()[0])
return count
def delete_intermediates(self) -> list[str]:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT image_name FROM images
- WHERE is_intermediate = TRUE;
- """
- )
- result = cast(list[sqlite3.Row], cursor.fetchall())
- image_names = [r[0] for r in result]
- cursor.execute(
- """--sql
- DELETE FROM images
- WHERE is_intermediate = TRUE;
- """
- )
- self._conn.commit()
- return image_names
- except sqlite3.Error as e:
- self._conn.rollback()
- raise ImageRecordDeleteException from e
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT image_name FROM images
+ WHERE is_intermediate = TRUE;
+ """
+ )
+ result = cast(list[sqlite3.Row], cursor.fetchall())
+ image_names = [r[0] for r in result]
+ cursor.execute(
+ """--sql
+ DELETE FROM images
+ WHERE is_intermediate = TRUE;
+ """
+ )
+ except sqlite3.Error as e:
+ raise ImageRecordDeleteException from e
+ return image_names
def save(
self,
@@ -315,73 +315,73 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
- try:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- INSERT OR IGNORE INTO images (
- image_name,
- image_origin,
- image_category,
- width,
- height,
- node_id,
- session_id,
- metadata,
- is_intermediate,
- starred,
- has_workflow
- )
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
- """,
- (
- image_name,
- image_origin.value,
- image_category.value,
- width,
- height,
- node_id,
- session_id,
- metadata,
- is_intermediate,
- starred,
- has_workflow,
- ),
- )
- self._conn.commit()
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ INSERT OR IGNORE INTO images (
+ image_name,
+ image_origin,
+ image_category,
+ width,
+ height,
+ node_id,
+ session_id,
+ metadata,
+ is_intermediate,
+ starred,
+ has_workflow
+ )
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);
+ """,
+ (
+ image_name,
+ image_origin.value,
+ image_category.value,
+ width,
+ height,
+ node_id,
+ session_id,
+ metadata,
+ is_intermediate,
+ starred,
+ has_workflow,
+ ),
+ )
- cursor.execute(
- """--sql
- SELECT created_at
- FROM images
- WHERE image_name = ?;
- """,
- (image_name,),
- )
+ cursor.execute(
+ """--sql
+ SELECT created_at
+ FROM images
+ WHERE image_name = ?;
+ """,
+ (image_name,),
+ )
- created_at = datetime.fromisoformat(cursor.fetchone()[0])
+ created_at = datetime.fromisoformat(cursor.fetchone()[0])
- return created_at
- except sqlite3.Error as e:
- self._conn.rollback()
- raise ImageRecordSaveException from e
+ except sqlite3.Error as e:
+ raise ImageRecordSaveException from e
+ return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT images.*
- FROM images
- JOIN board_images ON images.image_name = board_images.image_name
- WHERE board_images.board_id = ?
- AND images.is_intermediate = FALSE
- ORDER BY images.starred DESC, images.created_at DESC
- LIMIT 1;
- """,
- (board_id,),
- )
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT images.*
+ FROM images
+ JOIN board_images ON images.image_name = board_images.image_name
+ WHERE board_images.board_id = ?
+ AND images.is_intermediate = FALSE
+ ORDER BY images.starred DESC, images.created_at DESC
+ LIMIT 1;
+ """,
+ (board_id,),
+ )
- result = cast(Optional[sqlite3.Row], cursor.fetchone())
+ result = cast(Optional[sqlite3.Row], cursor.fetchone())
if result is None:
return None
@@ -398,85 +398,86 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
- # Build query conditions (reused for both starred count and image names queries)
- query_conditions = ""
- query_params: list[Union[int, str, bool]] = []
+ # Build query conditions (reused for both starred count and image names queries)
+ query_conditions = ""
+ query_params: list[Union[int, str, bool]] = []
- if image_origin is not None:
- query_conditions += """--sql
- AND images.image_origin = ?
- """
- query_params.append(image_origin.value)
+ if image_origin is not None:
+ query_conditions += """--sql
+ AND images.image_origin = ?
+ """
+ query_params.append(image_origin.value)
- if categories is not None:
- category_strings = [c.value for c in set(categories)]
- placeholders = ",".join("?" * len(category_strings))
- query_conditions += f"""--sql
- AND images.image_category IN ( {placeholders} )
- """
- for c in category_strings:
- query_params.append(c)
+ if categories is not None:
+ category_strings = [c.value for c in set(categories)]
+ placeholders = ",".join("?" * len(category_strings))
+ query_conditions += f"""--sql
+ AND images.image_category IN ( {placeholders} )
+ """
+ for c in category_strings:
+ query_params.append(c)
- if is_intermediate is not None:
- query_conditions += """--sql
- AND images.is_intermediate = ?
- """
- query_params.append(is_intermediate)
+ if is_intermediate is not None:
+ query_conditions += """--sql
+ AND images.is_intermediate = ?
+ """
+ query_params.append(is_intermediate)
- if board_id == "none":
- query_conditions += """--sql
- AND board_images.board_id IS NULL
- """
- elif board_id is not None:
- query_conditions += """--sql
- AND board_images.board_id = ?
- """
- query_params.append(board_id)
+ if board_id == "none":
+ query_conditions += """--sql
+ AND board_images.board_id IS NULL
+ """
+ elif board_id is not None:
+ query_conditions += """--sql
+ AND board_images.board_id = ?
+ """
+ query_params.append(board_id)
- if search_term:
- query_conditions += """--sql
- AND (
- images.metadata LIKE ?
- OR images.created_at LIKE ?
- )
- """
- query_params.append(f"%{search_term.lower()}%")
- query_params.append(f"%{search_term.lower()}%")
+ if search_term:
+ query_conditions += """--sql
+ AND (
+ images.metadata LIKE ?
+ OR images.created_at LIKE ?
+ )
+ """
+ query_params.append(f"%{search_term.lower()}%")
+ query_params.append(f"%{search_term.lower()}%")
- # Get starred count if starred_first is enabled
- starred_count = 0
- if starred_first:
- starred_count_query = f"""--sql
- SELECT COUNT(*)
- FROM images
- LEFT JOIN board_images ON board_images.image_name = images.image_name
- WHERE images.starred = TRUE AND (1=1{query_conditions})
- """
- cursor.execute(starred_count_query, query_params)
- starred_count = cast(int, cursor.fetchone()[0])
+ # Get starred count if starred_first is enabled
+ starred_count = 0
+ if starred_first:
+ starred_count_query = f"""--sql
+ SELECT COUNT(*)
+ FROM images
+ LEFT JOIN board_images ON board_images.image_name = images.image_name
+ WHERE images.starred = TRUE AND (1=1{query_conditions})
+ """
+ cursor.execute(starred_count_query, query_params)
+ starred_count = cast(int, cursor.fetchone()[0])
- # Get all image names with proper ordering
- if starred_first:
- names_query = f"""--sql
- SELECT images.image_name
- FROM images
- LEFT JOIN board_images ON board_images.image_name = images.image_name
- WHERE 1=1{query_conditions}
- ORDER BY images.starred DESC, images.created_at {order_dir.value}
- """
- else:
- names_query = f"""--sql
- SELECT images.image_name
- FROM images
- LEFT JOIN board_images ON board_images.image_name = images.image_name
- WHERE 1=1{query_conditions}
- ORDER BY images.created_at {order_dir.value}
- """
+ # Get all image names with proper ordering
+ if starred_first:
+ names_query = f"""--sql
+ SELECT images.image_name
+ FROM images
+ LEFT JOIN board_images ON board_images.image_name = images.image_name
+ WHERE 1=1{query_conditions}
+ ORDER BY images.starred DESC, images.created_at {order_dir.value}
+ """
+ else:
+ names_query = f"""--sql
+ SELECT images.image_name
+ FROM images
+ LEFT JOIN board_images ON board_images.image_name = images.image_name
+ WHERE 1=1{query_conditions}
+ ORDER BY images.created_at {order_dir.value}
+ """
- cursor.execute(names_query, query_params)
- result = cast(list[sqlite3.Row], cursor.fetchall())
+ cursor.execute(names_query, query_params)
+ result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return ImageNamesResult(image_names=image_names, starred_count=starred_count, total_count=len(image_names))
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index 04a3950d6c..fa8b7c0ff1 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -78,11 +78,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self._db = db
self._logger = logger
- @property
- def db(self) -> SqliteDatabase:
- """Return the underlying database."""
- return self._db
-
def add_model(self, config: AnyModelConfig) -> AnyModelConfig:
"""
Add a model to the database.
@@ -93,38 +88,34 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
- try:
- cursor = self._db.conn.cursor()
- cursor.execute(
- """--sql
- INSERT INTO models (
- id,
- config
- )
- VALUES (?,?);
- """,
- (
- config.key,
- config.model_dump_json(),
- ),
- )
- self._db.conn.commit()
+ with self._db.conn() as conn:
+ try:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ INSERT INTO models (
+ id,
+ config
+ )
+ VALUES (?,?);
+ """,
+ (
+ config.key,
+ config.model_dump_json(),
+ ),
+ )
- except sqlite3.IntegrityError as e:
- self._db.conn.rollback()
- if "UNIQUE constraint failed" in str(e):
- if "models.path" in str(e):
- msg = f"A model with path '{config.path}' is already installed"
- elif "models.name" in str(e):
- msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
+ except sqlite3.IntegrityError as e:
+ if "UNIQUE constraint failed" in str(e):
+ if "models.path" in str(e):
+ msg = f"A model with path '{config.path}' is already installed"
+ elif "models.name" in str(e):
+ msg = f"A model with name='{config.name}', type='{config.type}', base='{config.base}' is already installed"
+ else:
+ msg = f"A model with key '{config.key}' is already installed"
+ raise DuplicateModelException(msg) from e
else:
- msg = f"A model with key '{config.key}' is already installed"
- raise DuplicateModelException(msg) from e
- else:
- raise e
- except sqlite3.Error as e:
- self._db.conn.rollback()
- raise e
+ raise e
return self.get_model(config.key)
@@ -136,8 +127,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
- try:
- cursor = self._db.conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE FROM models
@@ -147,22 +138,18 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
- self._db.conn.commit()
- except sqlite3.Error as e:
- self._db.conn.rollback()
- raise e
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
- record = self.get_model(key)
+ with self._db.conn() as conn:
+ record = self.get_model(key)
- # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
- for field_name in changes.model_fields_set:
- setattr(record, field_name, getattr(changes, field_name))
+ # Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
+ for field_name in changes.model_fields_set:
+ setattr(record, field_name, getattr(changes, field_name))
- json_serialized = record.model_dump_json()
+ json_serialized = record.model_dump_json()
- try:
- cursor = self._db.conn.cursor()
+ cursor = conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -174,10 +161,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
)
if cursor.rowcount == 0:
raise UnknownModelException("model not found")
- self._db.conn.commit()
- except sqlite3.Error as e:
- self._db.conn.rollback()
- raise e
return self.get_model(key)
@@ -189,30 +172,32 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
- cursor = self._db.conn.cursor()
- cursor.execute(
- """--sql
- SELECT config, strftime('%s',updated_at) FROM models
- WHERE id=?;
- """,
- (key,),
- )
- rows = cursor.fetchone()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT config, strftime('%s',updated_at) FROM models
+ WHERE id=?;
+ """,
+ (key,),
+ )
+ rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
- cursor = self._db.conn.cursor()
- cursor.execute(
- """--sql
- SELECT config, strftime('%s',updated_at) FROM models
- WHERE hash=?;
- """,
- (hash,),
- )
- rows = cursor.fetchone()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT config, strftime('%s',updated_at) FROM models
+ WHERE hash=?;
+ """,
+ (hash,),
+ )
+ rows = cursor.fetchone()
if not rows:
raise UnknownModelException("model not found")
model = ModelConfigFactory.make_config(json.loads(rows[0]), timestamp=rows[1])
@@ -224,15 +209,16 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
- cursor = self._db.conn.cursor()
- cursor.execute(
- """--sql
- select count(*) FROM models
- WHERE id=?;
- """,
- (key,),
- )
- count = cursor.fetchone()[0]
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ select count(*) FROM models
+ WHERE id=?;
+ """,
+ (key,),
+ )
+ count = cursor.fetchone()[0]
return count > 0
def search_by_attr(
@@ -255,43 +241,43 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
+ with self._db.conn() as conn:
+ assert isinstance(order_by, ModelRecordOrderBy)
+ ordering = {
+ ModelRecordOrderBy.Default: "type, base, name, format",
+ ModelRecordOrderBy.Type: "type",
+ ModelRecordOrderBy.Base: "base",
+ ModelRecordOrderBy.Name: "name",
+ ModelRecordOrderBy.Format: "format",
+ }
- assert isinstance(order_by, ModelRecordOrderBy)
- ordering = {
- ModelRecordOrderBy.Default: "type, base, name, format",
- ModelRecordOrderBy.Type: "type",
- ModelRecordOrderBy.Base: "base",
- ModelRecordOrderBy.Name: "name",
- ModelRecordOrderBy.Format: "format",
- }
+ where_clause: list[str] = []
+ bindings: list[str] = []
+ if model_name:
+ where_clause.append("name=?")
+ bindings.append(model_name)
+ if base_model:
+ where_clause.append("base=?")
+ bindings.append(base_model)
+ if model_type:
+ where_clause.append("type=?")
+ bindings.append(model_type)
+ if model_format:
+ where_clause.append("format=?")
+ bindings.append(model_format)
+ where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
- where_clause: list[str] = []
- bindings: list[str] = []
- if model_name:
- where_clause.append("name=?")
- bindings.append(model_name)
- if base_model:
- where_clause.append("base=?")
- bindings.append(base_model)
- if model_type:
- where_clause.append("type=?")
- bindings.append(model_type)
- if model_format:
- where_clause.append("format=?")
- bindings.append(model_format)
- where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
-
- cursor = self._db.conn.cursor()
- cursor.execute(
- f"""--sql
- SELECT config, strftime('%s',updated_at)
- FROM models
- {where}
- ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
- """,
- tuple(bindings),
- )
- result = cursor.fetchall()
+ cursor = conn.cursor()
+ cursor.execute(
+ f"""--sql
+ SELECT config, strftime('%s',updated_at)
+ FROM models
+ {where}
+ ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason;
+ """,
+ tuple(bindings),
+ )
+ result = cursor.fetchall()
# Parse the model configs.
results: list[AnyModelConfig] = []
@@ -313,69 +299,72 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
- cursor = self._db.conn.cursor()
- cursor.execute(
- """--sql
- SELECT config, strftime('%s',updated_at) FROM models
- WHERE path=?;
- """,
- (str(path),),
- )
- results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT config, strftime('%s',updated_at) FROM models
+ WHERE path=?;
+ """,
+ (str(path),),
+ )
+ results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
- cursor = self._db.conn.cursor()
- cursor.execute(
- """--sql
- SELECT config, strftime('%s',updated_at) FROM models
- WHERE hash=?;
- """,
- (hash,),
- )
- results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT config, strftime('%s',updated_at) FROM models
+ WHERE hash=?;
+ """,
+ (hash,),
+ )
+ results = [ModelConfigFactory.make_config(json.loads(x[0]), timestamp=x[1]) for x in cursor.fetchall()]
return results
def list_models(
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
- assert isinstance(order_by, ModelRecordOrderBy)
- ordering = {
- ModelRecordOrderBy.Default: "type, base, name, format",
- ModelRecordOrderBy.Type: "type",
- ModelRecordOrderBy.Base: "base",
- ModelRecordOrderBy.Name: "name",
- ModelRecordOrderBy.Format: "format",
- }
+ with self._db.conn() as conn:
+ assert isinstance(order_by, ModelRecordOrderBy)
+ ordering = {
+ ModelRecordOrderBy.Default: "type, base, name, format",
+ ModelRecordOrderBy.Type: "type",
+ ModelRecordOrderBy.Base: "base",
+ ModelRecordOrderBy.Name: "name",
+ ModelRecordOrderBy.Format: "format",
+ }
- cursor = self._db.conn.cursor()
+ cursor = conn.cursor()
- # Lock so that the database isn't updated while we're doing the two queries.
- # query1: get the total number of model configs
- cursor.execute(
- """--sql
- select count(*) from models;
- """,
- (),
- )
- total = int(cursor.fetchone()[0])
+ # Lock so that the database isn't updated while we're doing the two queries.
+ # query1: get the total number of model configs
+ cursor.execute(
+ """--sql
+ select count(*) from models;
+ """,
+ (),
+ )
+ total = int(cursor.fetchone()[0])
- # query2: fetch key fields
- cursor.execute(
- f"""--sql
- SELECT config
- FROM models
- ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
- LIMIT ?
- OFFSET ?;
- """,
- (
- per_page,
- page * per_page,
- ),
- )
- rows = cursor.fetchall()
+ # query2: fetch key fields
+ cursor.execute(
+ f"""--sql
+ SELECT config
+ FROM models
+ ORDER BY {ordering[order_by]} -- using ? to bind doesn't work here for some reason
+ LIMIT ?
+ OFFSET ?;
+ """,
+ (
+ per_page,
+ page * per_page,
+ ),
+ )
+ rows = cursor.fetchall()
items = [ModelSummary.model_validate(dict(x)) for x in rows]
return PaginatedResults(page=page, pages=ceil(total / per_page), per_page=per_page, total=total, items=items)
diff --git a/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py
index 87890cb36e..aa429351aa 100644
--- a/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py
+++ b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py
@@ -1,5 +1,3 @@
-import sqlite3
-
from invokeai.app.services.model_relationship_records.model_relationship_records_base import (
ModelRelationshipRecordStorageBase,
)
@@ -9,58 +7,54 @@ from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
- self._conn = db.conn
+ self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
- if model_key_1 == model_key_2:
- raise ValueError("Cannot relate a model to itself.")
- a, b = sorted([model_key_1, model_key_2])
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ if model_key_1 == model_key_2:
+ raise ValueError("Cannot relate a model to itself.")
+ a, b = sorted([model_key_1, model_key_2])
+ cursor = conn.cursor()
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise e
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
- a, b = sorted([model_key_1, model_key_2])
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ a, b = sorted([model_key_1, model_key_2])
+ cursor = conn.cursor()
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
- self._conn.commit()
- except sqlite3.Error as e:
- self._conn.rollback()
- raise e
def get_related_model_keys(self, model_key: str) -> list[str]:
- cursor = self._conn.cursor()
- cursor.execute(
- """
- SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
- UNION
- SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
- """,
- (model_key, model_key),
- )
- return [row[0] for row in cursor.fetchall()]
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """
+ SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
+ UNION
+ SELECT model_key_1 FROM model_relationships WHERE model_key_2 = ?
+ """,
+ (model_key, model_key),
+ )
+ result = [row[0] for row in cursor.fetchall()]
+ return result
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
- key_list = ",".join("?" for _ in model_keys)
- cursor.execute(
- f"""
- SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
- UNION
- SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
- """,
- model_keys + model_keys,
- )
- return [row[0] for row in cursor.fetchall()]
+ key_list = ",".join("?" for _ in model_keys)
+ cursor.execute(
+ f"""
+ SELECT model_key_2 FROM model_relationships WHERE model_key_1 IN ({key_list})
+ UNION
+ SELECT model_key_1 FROM model_relationships WHERE model_key_2 IN ({key_list})
+ """,
+ model_keys + model_keys,
+ )
+ result = [row[0] for row in cursor.fetchall()]
+ return result
diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py
index 886f2ea205..491ec6c20c 100644
--- a/invokeai/app/services/session_queue/session_queue_sqlite.py
+++ b/invokeai/app/services/session_queue/session_queue_sqlite.py
@@ -50,103 +50,97 @@ class SqliteSessionQueue(SessionQueueBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
- self._conn = db.conn
+ self._db = db
def _set_in_progress_to_canceled(self) -> None:
"""
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
- try:
- cursor = self._conn.cursor()
- cursor.execute(
+ with self._db.conn() as conn:
+ conn.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
WHERE status = 'in_progress';
"""
)
- except Exception:
- self._conn.rollback()
- raise
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT count(*)
- FROM session_queue
- WHERE
- queue_id = ?
- AND status = 'pending'
- """,
- (queue_id,),
- )
- return cast(int, cursor.fetchone()[0])
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT count(*)
+ FROM session_queue
+ WHERE
+ queue_id = ?
+ AND status = 'pending'
+ """,
+ (queue_id,),
+ )
+ count = cast(int, cursor.fetchone()[0])
+ return count
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT MAX(priority)
- FROM session_queue
- WHERE
- queue_id = ?
- AND status = 'pending'
- """,
- (queue_id,),
- )
- return cast(Union[int, None], cursor.fetchone()[0]) or 0
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT MAX(priority)
+ FROM session_queue
+ WHERE
+ queue_id = ?
+ AND status = 'pending'
+ """,
+ (queue_id,),
+ )
+ priority = cast(Union[int, None], cursor.fetchone()[0]) or 0
+ return priority
async def enqueue_batch(self, queue_id: str, batch: Batch, prepend: bool) -> EnqueueBatchResult:
- try:
- # TODO: how does this work in a multi-user scenario?
- current_queue_size = self._get_current_queue_size(queue_id)
- max_queue_size = self.__invoker.services.configuration.max_queue_size
- max_new_queue_items = max_queue_size - current_queue_size
+ current_queue_size = self._get_current_queue_size(queue_id)
+ max_queue_size = self.__invoker.services.configuration.max_queue_size
+ max_new_queue_items = max_queue_size - current_queue_size
- priority = 0
- if prepend:
- priority = self._get_highest_priority(queue_id) + 1
+ priority = 0
+ if prepend:
+ priority = self._get_highest_priority(queue_id) + 1
- requested_count = await asyncio.to_thread(
- calc_session_count,
- batch=batch,
- )
- values_to_insert = await asyncio.to_thread(
- prepare_values_to_insert,
- queue_id=queue_id,
- batch=batch,
- priority=priority,
- max_new_queue_items=max_new_queue_items,
- )
- enqueued_count = len(values_to_insert)
+ requested_count = await asyncio.to_thread(
+ calc_session_count,
+ batch=batch,
+ )
+ values_to_insert = await asyncio.to_thread(
+ prepare_values_to_insert,
+ queue_id=queue_id,
+ batch=batch,
+ priority=priority,
+ max_new_queue_items=max_new_queue_items,
+ )
+ enqueued_count = len(values_to_insert)
- with self._conn:
- cursor = self._conn.cursor()
- cursor.executemany(
- """--sql
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.executemany(
+ """--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
- values_to_insert,
- )
- with self._conn:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
+ values_to_insert,
+ )
+ cursor.execute(
+ """--sql
SELECT item_id
FROM session_queue
WHERE batch_id = ?
ORDER BY item_id DESC;
""",
- (batch.batch_id,),
- )
- item_ids = [row[0] for row in cursor.fetchall()]
- except Exception:
- raise
+ (batch.batch_id,),
+ )
+ item_ids = [row[0] for row in cursor.fetchall()]
enqueue_result = EnqueueBatchResult(
queue_id=queue_id,
requested=requested_count,
@@ -159,19 +153,20 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT *
- FROM session_queue
- WHERE status = 'pending'
- ORDER BY
- priority DESC,
- item_id ASC
- LIMIT 1
- """
- )
- result = cast(Union[sqlite3.Row, None], cursor.fetchone())
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT *
+ FROM session_queue
+ WHERE status = 'pending'
+ ORDER BY
+ priority DESC,
+ item_id ASC
+ LIMIT 1
+ """
+ )
+ result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
queue_item = SessionQueueItem.queue_item_from_dict(dict(result))
@@ -179,40 +174,42 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT *
- FROM session_queue
- WHERE
- queue_id = ?
- AND status = 'pending'
- ORDER BY
- priority DESC,
- created_at ASC
- LIMIT 1
- """,
- (queue_id,),
- )
- result = cast(Union[sqlite3.Row, None], cursor.fetchone())
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT *
+ FROM session_queue
+ WHERE
+ queue_id = ?
+ AND status = 'pending'
+ ORDER BY
+ priority DESC,
+ created_at ASC
+ LIMIT 1
+ """,
+ (queue_id,),
+ )
+ result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT *
- FROM session_queue
- WHERE
- queue_id = ?
- AND status = 'in_progress'
- LIMIT 1
- """,
- (queue_id,),
- )
- result = cast(Union[sqlite3.Row, None], cursor.fetchone())
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT *
+ FROM session_queue
+ WHERE
+ queue_id = ?
+ AND status = 'in_progress'
+ LIMIT 1
+ """,
+ (queue_id,),
+ )
+ result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
return None
return SessionQueueItem.queue_item_from_dict(dict(result))
@@ -225,8 +222,8 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
@@ -234,12 +231,16 @@ class SqliteSessionQueue(SessionQueueBase):
(item_id,),
)
row = cursor.fetchone()
- if row is None:
- raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
- current_status = row[0]
- # Only update if not already finished (completed, failed or canceled)
- if current_status in ("completed", "failed", "canceled"):
- return self.get_queue_item(item_id)
+ if row is None:
+ raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
+ current_status = row[0]
+
+ # Only update if not already finished (completed, failed or canceled)
+ if current_status in ("completed", "failed", "canceled"):
+ return self.get_queue_item(item_id)
+
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
UPDATE session_queue
@@ -248,10 +249,7 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(status, error_type, error_message, error_traceback, item_id),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
+
queue_item = self.get_queue_item(item_id)
batch_status = self.get_batch_status(queue_id=queue_item.queue_id, batch_id=queue_item.batch_id)
queue_status = self.get_queue_status(queue_id=queue_item.queue_id)
@@ -259,35 +257,37 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT count(*)
- FROM session_queue
- WHERE queue_id = ?
- """,
- (queue_id,),
- )
- is_empty = cast(int, cursor.fetchone()[0]) == 0
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT count(*)
+ FROM session_queue
+ WHERE queue_id = ?
+ """,
+ (queue_id,),
+ )
+ is_empty = cast(int, cursor.fetchone()[0]) == 0
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT count(*)
- FROM session_queue
- WHERE queue_id = ?
- """,
- (queue_id,),
- )
- max_queue_size = self.__invoker.services.configuration.max_queue_size
- is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT count(*)
+ FROM session_queue
+ WHERE queue_id = ?
+ """,
+ (queue_id,),
+ )
+ max_queue_size = self.__invoker.services.configuration.max_queue_size
+ is_full = cast(int, cursor.fetchone()[0]) >= max_queue_size
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -305,24 +305,20 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
self.__invoker.services.events.emit_queue_cleared(queue_id)
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
where = """--sql
WHERE
- queue_id = ?
- AND (
+ queue_id = ?
+ AND (
status = 'completed'
OR status = 'failed'
OR status = 'canceled'
- )
+ )
"""
cursor.execute(
f"""--sql
@@ -341,10 +337,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return PruneResult(deleted=count)
def cancel_queue_item(self, item_id: int) -> SessionQueueItem:
@@ -357,8 +349,8 @@ class SqliteSessionQueue(SessionQueueBase):
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE
@@ -367,10 +359,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(item_id,),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
def complete_queue_item(self, item_id: int) -> SessionQueueItem:
queue_item = self._set_queue_item_status(item_id=item_id, status="completed")
@@ -393,8 +381,8 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -425,17 +413,15 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
- self._conn.commit()
- if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
- self._set_queue_item_status(current_queue_item.item_id, "canceled")
- except Exception:
- self._conn.rollback()
- raise
+
+ if current_queue_item is not None and current_queue_item.batch_id in batch_ids:
+ self._set_queue_item_status(current_queue_item.item_id, "canceled")
+
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -465,17 +451,13 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
- self._conn.commit()
- if current_queue_item is not None and current_queue_item.destination == destination:
- self._set_queue_item_status(current_queue_item.item_id, "canceled")
- except Exception:
- self._conn.rollback()
- raise
+ if current_queue_item is not None and current_queue_item.destination == destination:
+ self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
@@ -501,15 +483,11 @@ class SqliteSessionQueue(SessionQueueBase):
""",
params,
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
where = """--sql
WHERE
queue_id == ?
@@ -532,15 +510,11 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -569,18 +543,14 @@ class SqliteSessionQueue(SessionQueueBase):
""",
tuple(params),
)
- self._conn.commit()
- if current_queue_item is not None and current_queue_item.queue_id == queue_id:
- self._set_queue_item_status(current_queue_item.item_id, "canceled")
- except Exception:
- self._conn.rollback()
- raise
+ if current_queue_item is not None and current_queue_item.queue_id == queue_id:
+ self._set_queue_item_status(current_queue_item.item_id, "canceled")
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
where = """--sql
WHERE
queue_id == ?
@@ -603,30 +573,27 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(queue_id,),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT * FROM session_queue
- WHERE
- item_id = ?
- """,
- (item_id,),
- )
- result = cast(Union[sqlite3.Row, None], cursor.fetchone())
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT * FROM session_queue
+ WHERE
+ item_id = ?
+ """,
+ (item_id,),
+ )
+ result = cast(Union[sqlite3.Row, None], cursor.fetchone())
if result is None:
raise SessionQueueItemNotFoundError(f"No queue item with id {item_id}")
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -639,10 +606,6 @@ class SqliteSessionQueue(SessionQueueBase):
""",
(session_json, item_id),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return self.get_queue_item(item_id)
def list_queue_items(
@@ -654,42 +617,43 @@ class SqliteSessionQueue(SessionQueueBase):
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
- cursor_ = self._conn.cursor()
- item_id = cursor
- query = """--sql
- SELECT *
- FROM session_queue
- WHERE queue_id = ?
- """
- params: list[Union[str, int]] = [queue_id]
-
- if status is not None:
- query += """--sql
- AND status = ?
- """
- params.append(status)
-
- if destination is not None:
- query += """---sql
- AND destination = ?
+ with self._db.conn() as conn:
+ cursor_ = conn.cursor()
+ item_id = cursor
+ query = """--sql
+ SELECT *
+ FROM session_queue
+ WHERE queue_id = ?
"""
- params.append(destination)
+ params: list[Union[str, int]] = [queue_id]
- if item_id is not None:
- query += """--sql
- AND (priority < ?) OR (priority = ? AND item_id > ?)
+ if status is not None:
+ query += """--sql
+ AND status = ?
+ """
+ params.append(status)
+
+ if destination is not None:
+ query += """--sql
+ AND destination = ?
"""
- params.extend([priority, priority, item_id])
+ params.append(destination)
- query += """--sql
- ORDER BY
- priority DESC,
- item_id ASC
- LIMIT ?
- """
- params.append(limit + 1)
- cursor_.execute(query, params)
- results = cast(list[sqlite3.Row], cursor_.fetchall())
+ if item_id is not None:
+ query += """--sql
+ AND ((priority < ?) OR (priority = ? AND item_id > ?))
+ """
+ params.extend([priority, priority, item_id])
+
+ query += """--sql
+ ORDER BY
+ priority DESC,
+ item_id ASC
+ LIMIT ?
+ """
+ params.append(limit + 1)
+ cursor_.execute(query, params)
+ results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
has_more = False
if len(items) > limit:
@@ -704,43 +668,45 @@ class SqliteSessionQueue(SessionQueueBase):
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
- cursor_ = self._conn.cursor()
- query = """--sql
- SELECT *
- FROM session_queue
- WHERE queue_id = ?
- """
- params: list[Union[str, int]] = [queue_id]
-
- if destination is not None:
- query += """---sql
- AND destination = ?
+ with self._db.conn() as conn:
+ cursor_ = conn.cursor()
+ query = """--sql
+ SELECT *
+ FROM session_queue
+ WHERE queue_id = ?
"""
- params.append(destination)
+ params: list[Union[str, int]] = [queue_id]
- query += """--sql
- ORDER BY
- priority DESC,
- item_id ASC
- ;
- """
- cursor_.execute(query, params)
- results = cast(list[sqlite3.Row], cursor_.fetchall())
+ if destination is not None:
+ query += """--sql
+ AND destination = ?
+ """
+ params.append(destination)
+
+ query += """--sql
+ ORDER BY
+ priority DESC,
+ item_id ASC
+ ;
+ """
+ cursor_.execute(query, params)
+ results = cast(list[sqlite3.Row], cursor_.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT status, count(*)
- FROM session_queue
- WHERE queue_id = ?
- GROUP BY status
- """,
- (queue_id,),
- )
- counts_result = cast(list[sqlite3.Row], cursor.fetchall())
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT status, count(*)
+ FROM session_queue
+ WHERE queue_id = ?
+ GROUP BY status
+ """,
+ (queue_id,),
+ )
+ counts_result = cast(list[sqlite3.Row], cursor.fetchall())
current_item = self.get_current(queue_id=queue_id)
total = sum(row[1] or 0 for row in counts_result)
@@ -759,19 +725,20 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT status, count(*), origin, destination
- FROM session_queue
- WHERE
- queue_id = ?
- AND batch_id = ?
- GROUP BY status
- """,
- (queue_id, batch_id),
- )
- result = cast(list[sqlite3.Row], cursor.fetchall())
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT status, count(*), origin, destination
+ FROM session_queue
+ WHERE
+ queue_id = ?
+ AND batch_id = ?
+ GROUP BY status
+ """,
+ (queue_id, batch_id),
+ )
+ result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in result)
counts: dict[str, int] = {row[0]: row[1] for row in result}
origin = result[0]["origin"] if result else None
@@ -791,18 +758,19 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT status, count(*)
- FROM session_queue
- WHERE queue_id = ?
- AND destination = ?
- GROUP BY status
- """,
- (queue_id, destination),
- )
- counts_result = cast(list[sqlite3.Row], cursor.fetchall())
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT status, count(*)
+ FROM session_queue
+ WHERE queue_id = ?
+ AND destination = ?
+ GROUP BY status
+ """,
+ (queue_id, destination),
+ )
+ counts_result = cast(list[sqlite3.Row], cursor.fetchall())
total = sum(row[1] or 0 for row in counts_result)
counts: dict[str, int] = {row[0]: row[1] for row in counts_result}
@@ -820,8 +788,8 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
@@ -872,10 +840,6 @@ class SqliteSessionQueue(SessionQueueBase):
values_to_insert,
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
retry_result = RetryItemsResult(
queue_id=queue_id,
retried_item_ids=retried_item_ids,
diff --git a/invokeai/app/services/shared/sqlite/sqlite_database.py b/invokeai/app/services/shared/sqlite/sqlite_database.py
index e1895a41da..edfadd8142 100644
--- a/invokeai/app/services/shared/sqlite/sqlite_database.py
+++ b/invokeai/app/services/shared/sqlite/sqlite_database.py
@@ -1,4 +1,7 @@
import sqlite3
+import threading
+from collections.abc import Generator
+from contextlib import contextmanager
from logging import Logger
from pathlib import Path
@@ -26,46 +29,64 @@ class SqliteDatabase:
def __init__(self, db_path: Path | None, logger: Logger, verbose: bool = False) -> None:
"""Initializes the database. This is used internally by the class constructor."""
- self.logger = logger
- self.db_path = db_path
- self.verbose = verbose
+ self._logger = logger
+ self._db_path = db_path
+ self._verbose = verbose
+ self._lock = threading.RLock()
- if not self.db_path:
+ if not self._db_path:
logger.info("Initializing in-memory database")
else:
- self.db_path.parent.mkdir(parents=True, exist_ok=True)
- self.logger.info(f"Initializing database at {self.db_path}")
+ self._db_path.parent.mkdir(parents=True, exist_ok=True)
+ self._logger.info(f"Initializing database at {self._db_path}")
- self.conn = sqlite3.connect(database=self.db_path or sqlite_memory, check_same_thread=False)
- self.conn.row_factory = sqlite3.Row
+ self._conn = sqlite3.connect(database=self._db_path or sqlite_memory, check_same_thread=False)
+ self._conn.row_factory = sqlite3.Row
- if self.verbose:
- self.conn.set_trace_callback(self.logger.debug)
+ if self._verbose:
+ self._conn.set_trace_callback(self._logger.debug)
# Enable foreign key constraints
- self.conn.execute("PRAGMA foreign_keys = ON;")
+ self._conn.execute("PRAGMA foreign_keys = ON;")
# Enable Write-Ahead Logging (WAL) mode for better concurrency
- self.conn.execute("PRAGMA journal_mode = WAL;")
+ self._conn.execute("PRAGMA journal_mode = WAL;")
# Set a busy timeout to prevent database lockups during writes
- self.conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
+ self._conn.execute("PRAGMA busy_timeout = 5000;") # 5 seconds
def clean(self) -> None:
"""
Cleans the database by running the VACUUM command, reporting on the freed space.
"""
# No need to clean in-memory database
- if not self.db_path:
+ if not self._db_path:
return
try:
- initial_db_size = Path(self.db_path).stat().st_size
- self.conn.execute("VACUUM;")
- self.conn.commit()
- final_db_size = Path(self.db_path).stat().st_size
- freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
- if freed_space_in_mb > 0:
- self.logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
+ with self.conn() as conn:
+ initial_db_size = Path(self._db_path).stat().st_size
+ conn.execute("VACUUM;")
+ conn.commit()
+ final_db_size = Path(self._db_path).stat().st_size
+ freed_space_in_mb = round((initial_db_size - final_db_size) / 1024 / 1024, 2)
+ if freed_space_in_mb > 0:
+ self._logger.info(f"Cleaned database (freed {freed_space_in_mb}MB)")
except Exception as e:
- self.logger.error(f"Error cleaning database: {e}")
+ self._logger.error(f"Error cleaning database: {e}")
raise
+
+ @contextmanager
+ def conn(self) -> Generator[sqlite3.Connection, None, None]:
+ """
+ Thread-safe context manager for DB work.
+ Acquires the RLock, yields the Connection, then commits or rolls back.
+ """
+ self._lock.acquire()
+ try:
+ yield self._conn
+ self._conn.commit()
+ except BaseException:
+ self._conn.rollback()
+ raise
+ finally:
+ self._lock.release()
diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py
index 1a5798ac79..310abf0520 100644
--- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py
+++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_impl.py
@@ -32,7 +32,7 @@ class SqliteMigrator:
def __init__(self, db: SqliteDatabase) -> None:
self._db = db
- self._logger = db.logger
+ self._logger = db._logger
self._migration_set = MigrationSet()
self._backup_path: Optional[Path] = None
@@ -45,7 +45,7 @@ class SqliteMigrator:
"""Migrates the database to the latest version."""
# This throws if there is a problem.
self._migration_set.validate_migration_chain()
- cursor = self._db.conn.cursor()
+ cursor = self._db._conn.cursor()
self._create_migrations_table(cursor=cursor)
if self._migration_set.count == 0:
@@ -59,13 +59,13 @@ class SqliteMigrator:
self._logger.info("Database update needed")
# Make a backup of the db if it needs to be updated and is a file db
- if self._db.db_path is not None:
+ if self._db._db_path is not None:
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
- self._backup_path = self._db.db_path.parent / f"{self._db.db_path.stem}_backup_{timestamp}.db"
+ self._backup_path = self._db._db_path.parent / f"{self._db._db_path.stem}_backup_{timestamp}.db"
self._logger.info(f"Backing up database to {str(self._backup_path)}")
# Use SQLite to do the backup
with closing(sqlite3.connect(self._backup_path)) as backup_conn:
- self._db.conn.backup(backup_conn)
+ self._db._conn.backup(backup_conn)
else:
self._logger.info("Using in-memory database, no backup needed")
@@ -81,7 +81,7 @@ class SqliteMigrator:
try:
# Using sqlite3.Connection as a context manager commits a the transaction on exit, or rolls it back if an
# exception is raised.
- with self._db.conn as conn:
+ with self._db._conn as conn:
cursor = conn.cursor()
if self._get_current_version(cursor) != migration.from_version:
raise MigrationError(
diff --git a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py
index 33805e5ee2..22e3255de1 100644
--- a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py
+++ b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py
@@ -17,7 +17,7 @@ from invokeai.app.util.misc import uuid_string
class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
- self._conn = db.conn
+ self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -25,24 +25,25 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT *
- FROM style_presets
- WHERE id = ?;
- """,
- (style_preset_id,),
- )
- row = cursor.fetchone()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT *
+ FROM style_presets
+ WHERE id = ?;
+ """,
+ (style_preset_id,),
+ )
+ row = cursor.fetchone()
if row is None:
raise StylePresetNotFoundError(f"Style preset with id {style_preset_id} not found")
return StylePresetRecordDTO.from_dict(dict(row))
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -60,16 +61,12 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return self.get(style_preset_id)
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -90,16 +87,12 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
style_preset.type,
),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -122,15 +115,11 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
(changes.preset_data.model_dump_json(), style_preset_id),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE from style_presets
@@ -138,51 +127,44 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
""",
(style_preset_id,),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
- main_query = """
- SELECT
- *
- FROM style_presets
- """
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
- if type is not None:
- main_query += "WHERE type = ? "
+ main_query = """
+ SELECT
+ *
+ FROM style_presets
+ """
- main_query += "ORDER BY LOWER(name) ASC"
+ if type is not None:
+ main_query += "WHERE type = ? "
- cursor = self._conn.cursor()
- if type is not None:
- cursor.execute(main_query, (type,))
- else:
- cursor.execute(main_query)
+ main_query += "ORDER BY LOWER(name) ASC"
- rows = cursor.fetchall()
+ if type is not None:
+ cursor.execute(main_query, (type,))
+ else:
+ cursor.execute(main_query)
+
+ rows = cursor.fetchall()
style_presets = [StylePresetRecordDTO.from_dict(dict(row)) for row in rows]
return style_presets
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
-
- # First delete all existing default style presets
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ # First delete all existing default style presets
+ cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE FROM style_presets
WHERE type = "default";
"""
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
# Next, parse and create the default style presets
with open(Path(__file__).parent / Path("default_style_presets.json"), "r") as file:
presets = json.load(file)
diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py
index b84b226d9f..23547c083e 100644
--- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py
+++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py
@@ -25,7 +25,7 @@ SQL_TIME_FORMAT = "%Y-%m-%d %H:%M:%f"
class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def __init__(self, db: SqliteDatabase) -> None:
super().__init__()
- self._conn = db.conn
+ self._db = db
def start(self, invoker: Invoker) -> None:
self._invoker = invoker
@@ -33,16 +33,17 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
- cursor = self._conn.cursor()
- cursor.execute(
- """--sql
- SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
- FROM workflow_library
- WHERE workflow_id = ?;
- """,
- (workflow_id,),
- )
- row = cursor.fetchone()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ cursor.execute(
+ """--sql
+ SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
+ FROM workflow_library
+ WHERE workflow_id = ?;
+ """,
+ (workflow_id,),
+ )
+ row = cursor.fetchone()
if row is None:
raise WorkflowNotFoundError(f"Workflow with id {workflow_id} not found")
return WorkflowRecordDTO.from_dict(dict(row))
@@ -51,9 +52,10 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
- try:
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
- cursor = self._conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO workflow_library (
@@ -64,18 +66,14 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_with_id.id, workflow_with_id.model_dump_json()),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return self.get(workflow_with_id.id)
def update(self, workflow: Workflow) -> WorkflowRecordDTO:
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -84,18 +82,14 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow.model_dump_json(), workflow.id),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return self.get(workflow.id)
def delete(self, workflow_id: str) -> None:
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -103,10 +97,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
return None
def get_many(
@@ -121,108 +111,109 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
- # sanitize!
- assert order_by in WorkflowRecordOrderBy
- assert direction in SQLiteDirection
+ with self._db.conn() as conn:
+ # sanitize!
+ assert order_by in WorkflowRecordOrderBy
+ assert direction in SQLiteDirection
- # We will construct the query dynamically based on the query params
+ # We will construct the query dynamically based on the query params
- # The main query to get the workflows / counts
- main_query = """
- SELECT
- workflow_id,
- category,
- name,
- description,
- created_at,
- updated_at,
- opened_at,
- tags
- FROM workflow_library
- """
- count_query = "SELECT COUNT(*) FROM workflow_library"
+ # The main query to get the workflows / counts
+ main_query = """
+ SELECT
+ workflow_id,
+ category,
+ name,
+ description,
+ created_at,
+ updated_at,
+ opened_at,
+ tags
+ FROM workflow_library
+ """
+ count_query = "SELECT COUNT(*) FROM workflow_library"
- # Start with an empty list of conditions and params
- conditions: list[str] = []
- params: list[str | int] = []
+ # Start with an empty list of conditions and params
+ conditions: list[str] = []
+ params: list[str | int] = []
- if categories:
- # Categories is a list of WorkflowCategory enum values, and a single string in the DB
+ if categories:
+ # Categories is a list of WorkflowCategory enum values, and a single string in the DB
- # Ensure all categories are valid (is this necessary?)
- assert all(c in WorkflowCategory for c in categories)
+ # Ensure all categories are valid (is this necessary?)
+ assert all(c in WorkflowCategory for c in categories)
- # Construct a placeholder string for the number of categories
- placeholders = ", ".join("?" for _ in categories)
+ # Construct a placeholder string for the number of categories
+ placeholders = ", ".join("?" for _ in categories)
- # Construct the condition string & params
- category_condition = f"category IN ({placeholders})"
- category_params = [category.value for category in categories]
+ # Construct the condition string & params
+ category_condition = f"category IN ({placeholders})"
+ category_params = [category.value for category in categories]
- conditions.append(category_condition)
- params.extend(category_params)
+ conditions.append(category_condition)
+ params.extend(category_params)
- if tags:
- # Tags is a list of strings, and a single string in the DB
- # The string in the DB has no guaranteed format
+ if tags:
+ # Tags is a list of strings, and a single string in the DB
+ # The string in the DB has no guaranteed format
- # Construct a list of conditions for each tag
- tags_conditions = ["tags LIKE ?" for _ in tags]
- tags_conditions_joined = " OR ".join(tags_conditions)
- tags_condition = f"({tags_conditions_joined})"
+ # Construct a list of conditions for each tag
+ tags_conditions = ["tags LIKE ?" for _ in tags]
+ tags_conditions_joined = " OR ".join(tags_conditions)
+ tags_condition = f"({tags_conditions_joined})"
- # And the params for the tags, case-insensitive
- tags_params = [f"%{t.strip()}%" for t in tags]
+ # And the params for the tags, case-insensitive
+ tags_params = [f"%{t.strip()}%" for t in tags]
- conditions.append(tags_condition)
- params.extend(tags_params)
+ conditions.append(tags_condition)
+ params.extend(tags_params)
- if has_been_opened:
- conditions.append("opened_at IS NOT NULL")
- elif has_been_opened is False:
- conditions.append("opened_at IS NULL")
+ if has_been_opened:
+ conditions.append("opened_at IS NOT NULL")
+ elif has_been_opened is False:
+ conditions.append("opened_at IS NULL")
- # Ignore whitespace in the query
- stripped_query = query.strip() if query else None
- if stripped_query:
- # Construct a wildcard query for the name, description, and tags
- wildcard_query = "%" + stripped_query + "%"
- query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
+ # Ignore whitespace in the query
+ stripped_query = query.strip() if query else None
+ if stripped_query:
+ # Construct a wildcard query for the name, description, and tags
+ wildcard_query = "%" + stripped_query + "%"
+ query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
- conditions.append(query_condition)
- params.extend([wildcard_query, wildcard_query, wildcard_query])
+ conditions.append(query_condition)
+ params.extend([wildcard_query, wildcard_query, wildcard_query])
- if conditions:
- # If there are conditions, add a WHERE clause and then join the conditions
- main_query += " WHERE "
- count_query += " WHERE "
+ if conditions:
+ # If there are conditions, add a WHERE clause and then join the conditions
+ main_query += " WHERE "
+ count_query += " WHERE "
- all_conditions = " AND ".join(conditions)
- main_query += all_conditions
- count_query += all_conditions
+ all_conditions = " AND ".join(conditions)
+ main_query += all_conditions
+ count_query += all_conditions
- # After this point, the query and params differ for the main query and the count query
- main_params = params.copy()
- count_params = params.copy()
+ # After this point, the query and params differ for the main query and the count query
+ main_params = params.copy()
+ count_params = params.copy()
- # Main query also gets ORDER BY and LIMIT/OFFSET
- main_query += f" ORDER BY {order_by.value} {direction.value}"
+ # Main query also gets ORDER BY and LIMIT/OFFSET
+ main_query += f" ORDER BY {order_by.value} {direction.value}"
- if per_page:
- main_query += " LIMIT ? OFFSET ?"
- main_params.extend([per_page, page * per_page])
+ if per_page:
+ main_query += " LIMIT ? OFFSET ?"
+ main_params.extend([per_page, page * per_page])
- # Put a ring on it
- main_query += ";"
- count_query += ";"
+ # Put a ring on it
+ main_query += ";"
+ count_query += ";"
- cursor = self._conn.cursor()
- cursor.execute(main_query, main_params)
- rows = cursor.fetchall()
- workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
+ cursor = conn.cursor()
+ cursor.execute(main_query, main_params)
+ rows = cursor.fetchall()
+ workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
- cursor.execute(count_query, count_params)
- total = cursor.fetchone()[0]
+ cursor.execute(count_query, count_params)
+ total = cursor.fetchone()[0]
if per_page:
pages = total // per_page + (total % per_page > 0)
@@ -247,46 +238,47 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
- cursor = self._conn.cursor()
- result: dict[str, int] = {}
- # Base conditions for categories and selected tags
- base_conditions: list[str] = []
- base_params: list[str | int] = []
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ result: dict[str, int] = {}
+ # Base conditions for categories and selected tags
+ base_conditions: list[str] = []
+ base_params: list[str | int] = []
- # Add category conditions
- if categories:
- assert all(c in WorkflowCategory for c in categories)
- placeholders = ", ".join("?" for _ in categories)
- base_conditions.append(f"category IN ({placeholders})")
- base_params.extend([category.value for category in categories])
+ # Add category conditions
+ if categories:
+ assert all(c in WorkflowCategory for c in categories)
+ placeholders = ", ".join("?" for _ in categories)
+ base_conditions.append(f"category IN ({placeholders})")
+ base_params.extend([category.value for category in categories])
- if has_been_opened:
- base_conditions.append("opened_at IS NOT NULL")
- elif has_been_opened is False:
- base_conditions.append("opened_at IS NULL")
+ if has_been_opened:
+ base_conditions.append("opened_at IS NOT NULL")
+ elif has_been_opened is False:
+ base_conditions.append("opened_at IS NULL")
- # For each tag to count, run a separate query
- for tag in tags:
- # Start with the base conditions
- conditions = base_conditions.copy()
- params = base_params.copy()
+ # For each tag to count, run a separate query
+ for tag in tags:
+ # Start with the base conditions
+ conditions = base_conditions.copy()
+ params = base_params.copy()
- # Add this specific tag condition
- conditions.append("tags LIKE ?")
- params.append(f"%{tag.strip()}%")
+ # Add this specific tag condition
+ conditions.append("tags LIKE ?")
+ params.append(f"%{tag.strip()}%")
- # Construct the full query
- stmt = """--sql
- SELECT COUNT(*)
- FROM workflow_library
- """
+ # Construct the full query
+ stmt = """--sql
+ SELECT COUNT(*)
+ FROM workflow_library
+ """
- if conditions:
- stmt += " WHERE " + " AND ".join(conditions)
+ if conditions:
+ stmt += " WHERE " + " AND ".join(conditions)
- cursor.execute(stmt, params)
- count = cursor.fetchone()[0]
- result[tag] = count
+ cursor.execute(stmt, params)
+ count = cursor.fetchone()[0]
+ result[tag] = count
return result
@@ -296,52 +288,53 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
- cursor = self._conn.cursor()
- result: dict[str, int] = {}
- # Base conditions for categories
- base_conditions: list[str] = []
- base_params: list[str | int] = []
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
+ result: dict[str, int] = {}
+ # Base conditions for categories
+ base_conditions: list[str] = []
+ base_params: list[str | int] = []
- # Add category conditions
- if categories:
- assert all(c in WorkflowCategory for c in categories)
- placeholders = ", ".join("?" for _ in categories)
- base_conditions.append(f"category IN ({placeholders})")
- base_params.extend([category.value for category in categories])
+ # Add category conditions
+ if categories:
+ assert all(c in WorkflowCategory for c in categories)
+ placeholders = ", ".join("?" for _ in categories)
+ base_conditions.append(f"category IN ({placeholders})")
+ base_params.extend([category.value for category in categories])
- if has_been_opened:
- base_conditions.append("opened_at IS NOT NULL")
- elif has_been_opened is False:
- base_conditions.append("opened_at IS NULL")
+ if has_been_opened:
+ base_conditions.append("opened_at IS NOT NULL")
+ elif has_been_opened is False:
+ base_conditions.append("opened_at IS NULL")
- # For each category to count, run a separate query
- for category in categories:
- # Start with the base conditions
- conditions = base_conditions.copy()
- params = base_params.copy()
+ # For each category to count, run a separate query
+ for category in categories:
+ # Start with the base conditions
+ conditions = base_conditions.copy()
+ params = base_params.copy()
- # Add this specific category condition
- conditions.append("category = ?")
- params.append(category.value)
+ # Add this specific category condition
+ conditions.append("category = ?")
+ params.append(category.value)
- # Construct the full query
- stmt = """--sql
- SELECT COUNT(*)
- FROM workflow_library
- """
+ # Construct the full query
+ stmt = """--sql
+ SELECT COUNT(*)
+ FROM workflow_library
+ """
- if conditions:
- stmt += " WHERE " + " AND ".join(conditions)
+ if conditions:
+ stmt += " WHERE " + " AND ".join(conditions)
- cursor.execute(stmt, params)
- count = cursor.fetchone()[0]
- result[category.value] = count
+ cursor.execute(stmt, params)
+ count = cursor.fetchone()[0]
+ result[category.value] = count
return result
def update_opened_at(self, workflow_id: str) -> None:
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -350,10 +343,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(workflow_id,),
)
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
def _sync_default_workflows(self) -> None:
"""Syncs default workflows to the database. Internal use only."""
@@ -368,8 +357,8 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
- try:
- cursor = self._conn.cursor()
+ with self._db.conn() as conn:
+ cursor = conn.cursor()
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
@@ -449,8 +438,3 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
""",
(w.model_dump_json(), w.id),
)
-
- self._conn.commit()
- except Exception:
- self._conn.rollback()
- raise
diff --git a/tests/test_sqlite_migrator.py b/tests/test_sqlite_migrator.py
index 7f72d0bd13..f6a3cb2a5a 100644
--- a/tests/test_sqlite_migrator.py
+++ b/tests/test_sqlite_migrator.py
@@ -191,14 +191,14 @@ def test_migrator_registers_migration(migrator: SqliteMigrator, migration_no_op:
def test_migrator_creates_migrations_table(migrator: SqliteMigrator) -> None:
- cursor = migrator._db.conn.cursor()
+ cursor = migrator._db._conn.cursor()
migrator._create_migrations_table(cursor)
cursor.execute("SELECT * FROM sqlite_master WHERE type='table' AND name='migrations';")
assert cursor.fetchone() is not None
def test_migrator_migration_sets_version(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
- cursor = migrator._db.conn.cursor()
+ cursor = migrator._db._conn.cursor()
migrator._create_migrations_table(cursor)
migrator.register_migration(migration_no_op)
migrator.run_migrations()
@@ -207,7 +207,7 @@ def test_migrator_migration_sets_version(migrator: SqliteMigrator, migration_no_
def test_migrator_gets_current_version(migrator: SqliteMigrator, migration_no_op: Migration) -> None:
- cursor = migrator._db.conn.cursor()
+ cursor = migrator._db._conn.cursor()
assert migrator._get_current_version(cursor) == 0
migrator._create_migrations_table(cursor)
assert migrator._get_current_version(cursor) == 0
@@ -217,7 +217,7 @@ def test_migrator_gets_current_version(migrator: SqliteMigrator, migration_no_op
def test_migrator_runs_single_migration(migrator: SqliteMigrator, migration_create_test_table: Migration) -> None:
- cursor = migrator._db.conn.cursor()
+ cursor = migrator._db._conn.cursor()
migrator._create_migrations_table(cursor)
migrator._run_migration(migration_create_test_table)
assert migrator._get_current_version(cursor) == 1
@@ -226,7 +226,7 @@ def test_migrator_runs_single_migration(migrator: SqliteMigrator, migration_crea
def test_migrator_runs_all_migrations_in_memory(migrator: SqliteMigrator) -> None:
- cursor = migrator._db.conn.cursor()
+ cursor = migrator._db._conn.cursor()
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
for migration in migrations:
migrator.register_migration(migration)
@@ -247,7 +247,7 @@ def test_migrator_runs_all_migrations_file(logger: Logger) -> None:
original_db_cursor = original_db_conn.cursor()
assert SqliteMigrator._get_current_version(original_db_cursor) == 3
# Must manually close else we get an error on Windows
- db.conn.close()
+ db._conn.close()
def test_migrator_backs_up_db(logger: Logger) -> None:
@@ -255,9 +255,9 @@ def test_migrator_backs_up_db(logger: Logger) -> None:
original_db_path = Path(tempdir) / "invokeai.db"
db = SqliteDatabase(db_path=original_db_path, logger=logger, verbose=False)
# Write some data to the db to test for successful backup
- temp_cursor = db.conn.cursor()
+ temp_cursor = db._conn.cursor()
temp_cursor.execute("CREATE TABLE test (id INTEGER PRIMARY KEY);")
- db.conn.commit()
+ db._conn.commit()
# Set up the migrator
migrator = SqliteMigrator(db=db)
migrations = [Migration(from_version=i, to_version=i + 1, callback=create_migrate(i)) for i in range(0, 3)]
@@ -265,7 +265,7 @@ def test_migrator_backs_up_db(logger: Logger) -> None:
migrator.register_migration(migration)
migrator.run_migrations()
# Must manually close else we get an error on Windows
- db.conn.close()
+ db._conn.close()
assert original_db_path.exists()
# We should have a backup file when we migrated a file db
assert migrator._backup_path
@@ -279,7 +279,7 @@ def test_migrator_backs_up_db(logger: Logger) -> None:
def test_migrator_makes_no_changes_on_failed_migration(
migrator: SqliteMigrator, migration_no_op: Migration, failing_migrate_callback: MigrateCallback
) -> None:
- cursor = migrator._db.conn.cursor()
+ cursor = migrator._db._conn.cursor()
migrator.register_migration(migration_no_op)
migrator.run_migrations()
assert migrator._get_current_version(cursor) == 1
@@ -290,7 +290,7 @@ def test_migrator_makes_no_changes_on_failed_migration(
def test_idempotent_migrations(migrator: SqliteMigrator, migration_create_test_table: Migration) -> None:
- cursor = migrator._db.conn.cursor()
+ cursor = migrator._db._conn.cursor()
migrator.register_migration(migration_create_test_table)
migrator.run_migrations()
# not throwing is sufficient
From fc71849c245035c9ba5d217f57d2570caf7ca2bb Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 11 Jul 2025 08:12:26 +1000
Subject: [PATCH 03/17] feat(app): expose a cursor, not a connection in db util
---
.../board_image_records_sqlite.py | 20 ++---
.../board_records/board_records_sqlite.py | 19 ++--
.../image_records/image_records_sqlite.py | 49 ++++------
.../model_records/model_records_sql.py | 31 +++----
.../model_relationship_records_sqlite.py | 13 +--
.../session_queue/session_queue_sqlite.py | 89 +++++++------------
.../services/shared/sqlite/sqlite_database.py | 25 +++---
.../style_preset_records_sqlite.py | 22 ++---
.../workflow_records_sqlite.py | 28 ++----
9 files changed, 106 insertions(+), 190 deletions(-)
diff --git a/invokeai/app/services/board_image_records/board_image_records_sqlite.py b/invokeai/app/services/board_image_records/board_image_records_sqlite.py
index a6c178097e..0f914d0343 100644
--- a/invokeai/app/services/board_image_records/board_image_records_sqlite.py
+++ b/invokeai/app/services/board_image_records/board_image_records_sqlite.py
@@ -21,8 +21,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
board_id: str,
image_name: str,
) -> None:
- with self._db.conn() as conn:
- conn.execute(
+ with self._db.transaction() as cursor:
+ cursor.execute(
"""--sql
INSERT INTO board_images (board_id, image_name)
VALUES (?, ?)
@@ -35,8 +35,8 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> None:
- with self._db.conn() as conn:
- conn.execute(
+ with self._db.transaction() as cursor:
+ cursor.execute(
"""--sql
DELETE FROM board_images
WHERE image_name = ?;
@@ -50,8 +50,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
offset: int = 0,
limit: int = 10,
) -> OffsetPaginatedResults[ImageRecord]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
@@ -80,8 +79,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
categories: list[ImageCategory] | None,
is_intermediate: bool | None,
) -> list[str]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
params: list[str | bool] = []
# Base query is a join between images and board_images
@@ -137,8 +135,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
self,
image_name: str,
) -> Optional[str]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT board_id
@@ -153,8 +150,7 @@ class SqliteBoardImageRecordStorage(BoardImageRecordStorageBase):
return cast(str, result[0])
def get_image_count_for_board(self, board_id: str) -> int:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
diff --git a/invokeai/app/services/board_records/board_records_sqlite.py b/invokeai/app/services/board_records/board_records_sqlite.py
index 696ffab4b2..45fe33c540 100644
--- a/invokeai/app/services/board_records/board_records_sqlite.py
+++ b/invokeai/app/services/board_records/board_records_sqlite.py
@@ -23,9 +23,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self._db = db
def delete(self, board_id: str) -> None:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE FROM boards
@@ -40,10 +39,9 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self,
board_name: str,
) -> BoardRecord:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
board_id = uuid_string()
- cursor = conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO boards (board_id, board_name)
@@ -59,9 +57,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
self,
board_id: str,
) -> BoardRecord:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
cursor.execute(
"""--sql
SELECT *
@@ -83,9 +80,8 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
board_id: str,
changes: BoardChanges,
) -> BoardRecord:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
# Change the name of a board
if changes.board_name is not None:
cursor.execute(
@@ -131,9 +127,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
limit: int = 10,
include_archived: bool = False,
) -> OffsetPaginatedResults[BoardRecord]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
-
+ with self._db.transaction() as cursor:
# Build base query
base_query = """
SELECT *
@@ -179,8 +173,7 @@ class SqliteBoardRecordStorage(BoardRecordStorageBase):
def get_all(
self, order_by: BoardRecordOrderBy, direction: SQLiteDirection, include_archived: bool = False
) -> list[BoardRecord]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
if order_by == BoardRecordOrderBy.Name:
base_query = """
SELECT *
diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py
index 78f725ba17..cb968e76bb 100644
--- a/invokeai/app/services/image_records/image_records_sqlite.py
+++ b/invokeai/app/services/image_records/image_records_sqlite.py
@@ -27,9 +27,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
self._db = db
def get(self, image_name: str) -> ImageRecord:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
cursor.execute(
f"""--sql
SELECT {IMAGE_DTO_COLS} FROM images
@@ -48,9 +47,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return deserialize_image_record(dict(result))
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
cursor.execute(
"""--sql
SELECT metadata FROM images
@@ -76,9 +74,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
image_name: str,
changes: ImageRecordChanges,
) -> None:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
# Change the category of the image
if changes.image_category is not None:
cursor.execute(
@@ -138,9 +135,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> OffsetPaginatedResults[ImageRecord]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
-
+ with self._db.transaction() as cursor:
# Manually build two queries - one for the count, one for the records
count_query = """--sql
SELECT COUNT(*)
@@ -227,20 +222,20 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
# Build the list of images, deserializing each row
cursor.execute(images_query, images_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
- images = [deserialize_image_record(dict(r)) for r in result]
- # Set up and execute the count query, without pagination
- count_query += query_conditions + ";"
- count_params = query_params.copy()
- cursor.execute(count_query, count_params)
- count = cast(int, cursor.fetchone()[0])
+ images = [deserialize_image_record(dict(r)) for r in result]
+
+ # Set up and execute the count query, without pagination
+ count_query += query_conditions + ";"
+ count_params = query_params.copy()
+ cursor.execute(count_query, count_params)
+ count = cast(int, cursor.fetchone()[0])
return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count)
def delete(self, image_name: str) -> None:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE FROM images
@@ -252,10 +247,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
raise ImageRecordDeleteException from e
def delete_many(self, image_names: list[str]) -> None:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
-
placeholders = ",".join("?" for _ in image_names)
# Construct the SQLite query with the placeholders
@@ -268,8 +261,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
raise ImageRecordDeleteException from e
def get_intermediates_count(self) -> int:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*) FROM images
@@ -280,9 +272,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return count
def delete_intermediates(self) -> list[str]:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
cursor.execute(
"""--sql
SELECT image_name FROM images
@@ -315,9 +306,8 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
node_id: Optional[str] = None,
metadata: Optional[str] = None,
) -> datetime:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
cursor.execute(
"""--sql
INSERT OR IGNORE INTO images (
@@ -366,8 +356,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
return created_at
def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT images.*
@@ -398,9 +387,7 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> ImageNamesResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
-
+ with self._db.transaction() as cursor:
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
diff --git a/invokeai/app/services/model_records/model_records_sql.py b/invokeai/app/services/model_records/model_records_sql.py
index fa8b7c0ff1..e3b24a6e62 100644
--- a/invokeai/app/services/model_records/model_records_sql.py
+++ b/invokeai/app/services/model_records/model_records_sql.py
@@ -88,9 +88,8 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise DuplicateModelException and InvalidModelConfigException exceptions.
"""
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
try:
- cursor = conn.cursor()
cursor.execute(
"""--sql
INSERT INTO models (
@@ -127,8 +126,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Can raise an UnknownModelException
"""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE FROM models
@@ -140,7 +138,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
raise UnknownModelException("model not found")
def update_model(self, key: str, changes: ModelRecordChanges) -> AnyModelConfig:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
record = self.get_model(key)
# Model configs use pydantic's `validate_assignment`, so each change is validated by pydantic.
@@ -149,7 +147,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
json_serialized = record.model_dump_json()
- cursor = conn.cursor()
cursor.execute(
"""--sql
UPDATE models
@@ -172,8 +169,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
Exceptions: UnknownModelException
"""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
@@ -188,8 +184,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
return model
def get_model_by_hash(self, hash: str) -> AnyModelConfig:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
@@ -209,8 +204,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
:param key: Unique key for the model to be deleted
"""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
select count(*) FROM models
@@ -241,7 +235,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
If none of the optional filters are passed, will return all
models in the database.
"""
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
@@ -267,7 +261,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
bindings.append(model_format)
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
- cursor = conn.cursor()
cursor.execute(
f"""--sql
SELECT config, strftime('%s',updated_at)
@@ -299,8 +292,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_path(self, path: Union[str, Path]) -> List[AnyModelConfig]:
"""Return models with the indicated path."""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
@@ -313,8 +305,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
def search_by_hash(self, hash: str) -> List[AnyModelConfig]:
"""Return models with the indicated hash."""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT config, strftime('%s',updated_at) FROM models
@@ -329,7 +320,7 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
self, page: int = 0, per_page: int = 10, order_by: ModelRecordOrderBy = ModelRecordOrderBy.Default
) -> PaginatedResults[ModelSummary]:
"""Return a paginated summary listing of each model in the database."""
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
assert isinstance(order_by, ModelRecordOrderBy)
ordering = {
ModelRecordOrderBy.Default: "type, base, name, format",
@@ -339,8 +330,6 @@ class ModelRecordServiceSQL(ModelRecordServiceBase):
ModelRecordOrderBy.Format: "format",
}
- cursor = conn.cursor()
-
# Lock so that the database isn't updated while we're doing the two queries.
# query1: get the total number of model configs
cursor.execute(
diff --git a/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py
index aa429351aa..c12990b8c3 100644
--- a/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py
+++ b/invokeai/app/services/model_relationship_records/model_relationship_records_sqlite.py
@@ -10,28 +10,25 @@ class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
self._db = db
def add_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
if model_key_1 == model_key_2:
raise ValueError("Cannot relate a model to itself.")
a, b = sorted([model_key_1, model_key_2])
- cursor = conn.cursor()
cursor.execute(
"INSERT OR IGNORE INTO model_relationships (model_key_1, model_key_2) VALUES (?, ?)",
(a, b),
)
def remove_model_relationship(self, model_key_1: str, model_key_2: str) -> None:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
a, b = sorted([model_key_1, model_key_2])
- cursor = conn.cursor()
cursor.execute(
"DELETE FROM model_relationships WHERE model_key_1 = ? AND model_key_2 = ?",
(a, b),
)
def get_related_model_keys(self, model_key: str) -> list[str]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""
SELECT model_key_2 FROM model_relationships WHERE model_key_1 = ?
@@ -44,9 +41,7 @@ class SqliteModelRelationshipRecordStorage(ModelRelationshipRecordStorageBase):
return result
def get_related_model_keys_batch(self, model_keys: list[str]) -> list[str]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
-
+ with self._db.transaction() as cursor:
key_list = ",".join("?" for _ in model_keys)
cursor.execute(
f"""
diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py
index 491ec6c20c..2e450399bc 100644
--- a/invokeai/app/services/session_queue/session_queue_sqlite.py
+++ b/invokeai/app/services/session_queue/session_queue_sqlite.py
@@ -57,8 +57,8 @@ class SqliteSessionQueue(SessionQueueBase):
Sets all in_progress queue items to canceled. Run on app startup, not associated with any queue.
This is necessary because the invoker may have been killed while processing a queue item.
"""
- with self._db.conn() as conn:
- conn.execute(
+ with self._db.transaction() as cursor:
+ cursor.execute(
"""--sql
UPDATE session_queue
SET status = 'canceled'
@@ -68,8 +68,7 @@ class SqliteSessionQueue(SessionQueueBase):
def _get_current_queue_size(self, queue_id: str) -> int:
"""Gets the current number of pending queue items"""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
@@ -85,8 +84,7 @@ class SqliteSessionQueue(SessionQueueBase):
def _get_highest_priority(self, queue_id: str) -> int:
"""Gets the highest priority value in the queue"""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT MAX(priority)
@@ -122,8 +120,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
enqueued_count = len(values_to_insert)
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.executemany(
"""--sql
INSERT INTO session_queue (queue_id, session, session_id, batch_id, field_values, priority, workflow, origin, destination, retried_from_item_id)
@@ -153,8 +150,7 @@ class SqliteSessionQueue(SessionQueueBase):
return enqueue_result
def dequeue(self) -> Optional[SessionQueueItem]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
@@ -174,8 +170,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def get_next(self, queue_id: str) -> Optional[SessionQueueItem]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
@@ -196,8 +191,7 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result))
def get_current(self, queue_id: str) -> Optional[SessionQueueItem]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
@@ -222,8 +216,7 @@ class SqliteSessionQueue(SessionQueueBase):
error_message: Optional[str] = None,
error_traceback: Optional[str] = None,
) -> SessionQueueItem:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status FROM session_queue WHERE item_id = ?
@@ -239,8 +232,7 @@ class SqliteSessionQueue(SessionQueueBase):
if current_status in ("completed", "failed", "canceled"):
return self.get_queue_item(item_id)
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE session_queue
@@ -257,8 +249,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def is_empty(self, queue_id: str) -> IsEmptyResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
@@ -271,8 +262,7 @@ class SqliteSessionQueue(SessionQueueBase):
return IsEmptyResult(is_empty=is_empty)
def is_full(self, queue_id: str) -> IsFullResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT count(*)
@@ -286,8 +276,7 @@ class SqliteSessionQueue(SessionQueueBase):
return IsFullResult(is_full=is_full)
def clear(self, queue_id: str) -> ClearResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT COUNT(*)
@@ -309,8 +298,7 @@ class SqliteSessionQueue(SessionQueueBase):
return ClearResult(deleted=count)
def prune(self, queue_id: str) -> PruneResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id = ?
@@ -349,8 +337,7 @@ class SqliteSessionQueue(SessionQueueBase):
self.cancel_queue_item(item_id)
except SessionQueueItemNotFoundError:
pass
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE
@@ -381,8 +368,7 @@ class SqliteSessionQueue(SessionQueueBase):
return queue_item
def cancel_by_batch_ids(self, queue_id: str, batch_ids: list[str]) -> CancelByBatchIDsResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
placeholders = ", ".join(["?" for _ in batch_ids])
where = f"""--sql
@@ -420,8 +406,7 @@ class SqliteSessionQueue(SessionQueueBase):
return CancelByBatchIDsResult(canceled=count)
def cancel_by_destination(self, queue_id: str, destination: str) -> CancelByDestinationResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -456,8 +441,7 @@ class SqliteSessionQueue(SessionQueueBase):
return CancelByDestinationResult(canceled=count)
def delete_by_destination(self, queue_id: str, destination: str) -> DeleteByDestinationResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
if current_queue_item is not None and current_queue_item.destination == destination:
self.cancel_queue_item(current_queue_item.item_id)
@@ -486,8 +470,7 @@ class SqliteSessionQueue(SessionQueueBase):
return DeleteByDestinationResult(deleted=count)
def delete_all_except_current(self, queue_id: str) -> DeleteAllExceptCurrentResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -513,8 +496,7 @@ class SqliteSessionQueue(SessionQueueBase):
return DeleteAllExceptCurrentResult(deleted=count)
def cancel_by_queue_id(self, queue_id: str) -> CancelByQueueIDResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
current_queue_item = self.get_current(queue_id)
where = """--sql
WHERE
@@ -549,8 +531,7 @@ class SqliteSessionQueue(SessionQueueBase):
return CancelByQueueIDResult(canceled=count)
def cancel_all_except_current(self, queue_id: str) -> CancelAllExceptCurrentResult:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
where = """--sql
WHERE
queue_id == ?
@@ -576,8 +557,7 @@ class SqliteSessionQueue(SessionQueueBase):
return CancelAllExceptCurrentResult(canceled=count)
def get_queue_item(self, item_id: int) -> SessionQueueItem:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT * FROM session_queue
@@ -592,8 +572,7 @@ class SqliteSessionQueue(SessionQueueBase):
return SessionQueueItem.queue_item_from_dict(dict(result))
def set_queue_item_session(self, item_id: int, session: GraphExecutionState) -> SessionQueueItem:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
# Use exclude_none so we don't end up with a bunch of nulls in the graph - this can cause validation errors
# when the graph is loaded. Graph execution occurs purely in memory - the session saved here is not referenced
# during execution.
@@ -617,8 +596,7 @@ class SqliteSessionQueue(SessionQueueBase):
status: Optional[QUEUE_ITEM_STATUS] = None,
destination: Optional[str] = None,
) -> CursorPaginatedResults[SessionQueueItem]:
- with self._db.conn() as conn:
- cursor_ = conn.cursor()
+ with self._db.transaction() as cursor_:
item_id = cursor
query = """--sql
SELECT *
@@ -668,8 +646,7 @@ class SqliteSessionQueue(SessionQueueBase):
destination: Optional[str] = None,
) -> list[SessionQueueItem]:
"""Gets all queue items that match the given parameters"""
- with self._db.conn() as conn:
- cursor_ = conn.cursor()
+ with self._db.transaction() as cursor:
query = """--sql
SELECT *
FROM session_queue
@@ -689,14 +666,13 @@ class SqliteSessionQueue(SessionQueueBase):
item_id ASC
;
"""
- cursor_.execute(query, params)
- results = cast(list[sqlite3.Row], cursor_.fetchall())
+ cursor.execute(query, params)
+ results = cast(list[sqlite3.Row], cursor.fetchall())
items = [SessionQueueItem.queue_item_from_dict(dict(result)) for result in results]
return items
def get_queue_status(self, queue_id: str) -> SessionQueueStatus:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
@@ -725,8 +701,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_batch_status(self, queue_id: str, batch_id: str) -> BatchStatus:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*), origin, destination
@@ -758,8 +733,7 @@ class SqliteSessionQueue(SessionQueueBase):
)
def get_counts_by_destination(self, queue_id: str, destination: str) -> SessionQueueCountsByDestination:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT status, count(*)
@@ -788,8 +762,7 @@ class SqliteSessionQueue(SessionQueueBase):
def retry_items_by_id(self, queue_id: str, item_ids: list[int]) -> RetryItemsResult:
"""Retries the given queue items"""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
values_to_insert: list[ValueToInsertTuple] = []
retried_item_ids: list[int] = []
diff --git a/invokeai/app/services/shared/sqlite/sqlite_database.py b/invokeai/app/services/shared/sqlite/sqlite_database.py
index edfadd8142..d14d803970 100644
--- a/invokeai/app/services/shared/sqlite/sqlite_database.py
+++ b/invokeai/app/services/shared/sqlite/sqlite_database.py
@@ -63,7 +63,7 @@ class SqliteDatabase:
if not self._db_path:
return
try:
- with self.conn() as conn:
+ with self._conn as conn:
initial_db_size = Path(self._db_path).stat().st_size
conn.execute("VACUUM;")
conn.commit()
@@ -76,17 +76,18 @@ class SqliteDatabase:
raise
@contextmanager
- def conn(self) -> Generator[sqlite3.Connection]:
+ def transaction(self) -> Generator[sqlite3.Cursor, None, None]:
"""
Thread-safe context manager for DB work.
- Acquires the RLock, yields the Connection, then commits or rolls back.
+ Acquires the RLock, yields a Cursor, then commits or rolls back.
"""
- self._lock.acquire()
- try:
- yield self._conn
- self._conn.commit()
- except:
- self._conn.rollback()
- raise
- finally:
- self._lock.release()
+ with self._lock:
+ cursor = self._conn.cursor()
+ try:
+ yield cursor
+ self._conn.commit()
+ except:
+ self._conn.rollback()
+ raise
+ finally:
+ cursor.close()
diff --git a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py
index 22e3255de1..35819fa0f0 100644
--- a/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py
+++ b/invokeai/app/services/style_preset_records/style_preset_records_sqlite.py
@@ -25,8 +25,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def get(self, style_preset_id: str) -> StylePresetRecordDTO:
"""Gets a style preset by ID."""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT *
@@ -42,8 +41,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def create(self, style_preset: StylePresetWithoutId) -> StylePresetRecordDTO:
style_preset_id = uuid_string()
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
INSERT OR IGNORE INTO style_presets (
@@ -65,8 +63,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def create_many(self, style_presets: list[StylePresetWithoutId]) -> None:
style_preset_ids = []
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
for style_preset in style_presets:
style_preset_id = uuid_string()
style_preset_ids.append(style_preset_id)
@@ -91,8 +88,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
return None
def update(self, style_preset_id: str, changes: StylePresetChanges) -> StylePresetRecordDTO:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
# Change the name of a style preset
if changes.name is not None:
cursor.execute(
@@ -118,8 +114,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
return self.get(style_preset_id)
def delete(self, style_preset_id: str) -> None:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from style_presets
@@ -130,9 +125,7 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
return None
def get_many(self, type: PresetType | None = None) -> list[StylePresetRecordDTO]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
-
+ with self._db.transaction() as cursor:
main_query = """
SELECT
*
@@ -156,9 +149,8 @@ class SqliteStylePresetRecordsStorage(StylePresetRecordsStorageBase):
def _sync_default_style_presets(self) -> None:
"""Syncs default style presets to the database. Internal use only."""
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
# First delete all existing default style presets
- cursor = conn.cursor()
cursor.execute(
"""--sql
DELETE FROM style_presets
diff --git a/invokeai/app/services/workflow_records/workflow_records_sqlite.py b/invokeai/app/services/workflow_records/workflow_records_sqlite.py
index 23547c083e..72f37469de 100644
--- a/invokeai/app/services/workflow_records/workflow_records_sqlite.py
+++ b/invokeai/app/services/workflow_records/workflow_records_sqlite.py
@@ -33,8 +33,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
def get(self, workflow_id: str) -> WorkflowRecordDTO:
"""Gets a workflow by ID. Updates the opened_at column."""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
SELECT workflow_id, workflow, name, created_at, updated_at, opened_at
@@ -52,9 +51,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be created via this method")
- with self._db.conn() as conn:
- cursor = conn.cursor()
-
+ with self._db.transaction() as cursor:
workflow_with_id = Workflow(**workflow.model_dump(), id=uuid_string())
cursor.execute(
"""--sql
@@ -72,8 +69,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be updated")
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
UPDATE workflow_library
@@ -88,8 +84,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if self.get(workflow_id).workflow.meta.category is WorkflowCategory.Default:
raise ValueError("Default workflows cannot be deleted")
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
"""--sql
DELETE from workflow_library
@@ -111,7 +106,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> PaginatedResults[WorkflowRecordListItemDTO]:
- with self._db.conn() as conn:
+ with self._db.transaction() as cursor:
# sanitize!
assert order_by in WorkflowRecordOrderBy
assert direction in SQLiteDirection
@@ -207,7 +202,6 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
main_query += ";"
count_query += ";"
- cursor = conn.cursor()
cursor.execute(main_query, main_params)
rows = cursor.fetchall()
workflows = [WorkflowRecordListItemDTOValidator.validate_python(dict(row)) for row in rows]
@@ -238,8 +232,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
if not tags:
return {}
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories and selected tags
base_conditions: list[str] = []
@@ -288,8 +281,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
has_been_opened: Optional[bool] = None,
is_published: Optional[bool] = None,
) -> dict[str, int]:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
result: dict[str, int] = {}
# Base conditions for categories
base_conditions: list[str] = []
@@ -333,8 +325,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
return result
def update_opened_at(self, workflow_id: str) -> None:
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
cursor.execute(
f"""--sql
UPDATE workflow_library
@@ -357,8 +348,7 @@ class SqliteWorkflowRecordsStorage(WorkflowRecordsStorageBase):
meaningless, as they are overwritten every time the server starts.
"""
- with self._db.conn() as conn:
- cursor = conn.cursor()
+ with self._db.transaction() as cursor:
workflows_from_file: list[Workflow] = []
workflows_to_update: list[Workflow] = []
workflows_to_add: list[Workflow] = []
From ac981879efc4c7103b8bd74973fded2b7de8a4bd Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 11 Jul 2025 07:38:08 +1000
Subject: [PATCH 04/17] fix(ui): runtime errors related to calling reduce on
array iterator
Fix an issue in certain browsers/builds causing a runtime error.
A zod enum has a .options property, which is an array of all the options
for the enum. This is handy for when you need to derive something from a
zod schema.
In this case, we represented the possible focus regions in the zod enum,
then derived a mapping of region names to set of target HTML elements.
Why isn't important, but suffice to say, we were using the .options
property for this.
But actually, we were using .options.values(), then calling .reduce() on
that. An array's .values() method returns an _array iterator_. Array
iterators do not have .reduce() methods!
Except, apparently in some environments they do - it depends on the JS
engine and whether or not polyfills for iterator helpers were included
in the build.
Turns out my dev environment - and most user browsers - do provide
.reduce(), so we didn't catch this error. It took a large deployment and
error monitoring to catch it.
I've refactored the code to totally avoid deriving data from zod in this
way.
---
invokeai/frontend/web/src/common/hooks/focus.ts | 15 +++++++--------
1 file changed, 7 insertions(+), 8 deletions(-)
diff --git a/invokeai/frontend/web/src/common/hooks/focus.ts b/invokeai/frontend/web/src/common/hooks/focus.ts
index 2d0510e751..4e093c5c63 100644
--- a/invokeai/frontend/web/src/common/hooks/focus.ts
+++ b/invokeai/frontend/web/src/common/hooks/focus.ts
@@ -6,7 +6,6 @@ import { atom, computed } from 'nanostores';
import type { RefObject } from 'react';
import { useEffect } from 'react';
import { objectKeys } from 'tsafe';
-import z from 'zod/v4';
/**
* We need to manage focus regions to conditionally enable hotkeys:
@@ -28,10 +27,7 @@ import z from 'zod/v4';
const log = logger('system');
-/**
- * The names of the focus regions.
- */
-const zFocusRegionName = z.enum([
+const REGION_NAMES = [
'launchpad',
'viewer',
'gallery',
@@ -41,13 +37,16 @@ const zFocusRegionName = z.enum([
'workflows',
'progress',
'settings',
-]);
-export type FocusRegionName = z.infer;
+] as const;
+/**
+ * The names of the focus regions.
+ */
+export type FocusRegionName = (typeof REGION_NAMES)[number];
/**
* A map of focus regions to the elements that are part of that region.
*/
-const REGION_TARGETS: Record> = zFocusRegionName.options.values().reduce(
+const REGION_TARGETS: Record> = REGION_NAMES.reduce(
(acc, region) => {
acc[region] = new Set();
return acc;
From 988d7ba24c15a3a0676279c1e5f3634b3be7ad5d Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 11 Jul 2025 08:22:48 +1000
Subject: [PATCH 05/17] chore: bump version to v6.0.1rc1
---
invokeai/version/invokeai_version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py
index 0f607a5d2d..5a9f9c3295 100644
--- a/invokeai/version/invokeai_version.py
+++ b/invokeai/version/invokeai_version.py
@@ -1 +1 @@
-__version__ = "6.0.0"
+__version__ = "6.0.1rc1"
From 694c85b041c7a677b0b3010ef506b44fad1366f1 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 11 Jul 2025 13:27:44 +1000
Subject: [PATCH 06/17] fix(ui): language file filenames
Need to replace the underscores w/ dashes - this was missed in #8246.
---
invokeai/frontend/web/public/locales/{pt_BR.json => pt-BR.json} | 0
invokeai/frontend/web/public/locales/{zh_CN.json => zh-CN.json} | 0
.../frontend/web/public/locales/{zh_Hant.json => zh-Hant.json} | 0
3 files changed, 0 insertions(+), 0 deletions(-)
rename invokeai/frontend/web/public/locales/{pt_BR.json => pt-BR.json} (100%)
rename invokeai/frontend/web/public/locales/{zh_CN.json => zh-CN.json} (100%)
rename invokeai/frontend/web/public/locales/{zh_Hant.json => zh-Hant.json} (100%)
diff --git a/invokeai/frontend/web/public/locales/pt_BR.json b/invokeai/frontend/web/public/locales/pt-BR.json
similarity index 100%
rename from invokeai/frontend/web/public/locales/pt_BR.json
rename to invokeai/frontend/web/public/locales/pt-BR.json
diff --git a/invokeai/frontend/web/public/locales/zh_CN.json b/invokeai/frontend/web/public/locales/zh-CN.json
similarity index 100%
rename from invokeai/frontend/web/public/locales/zh_CN.json
rename to invokeai/frontend/web/public/locales/zh-CN.json
diff --git a/invokeai/frontend/web/public/locales/zh_Hant.json b/invokeai/frontend/web/public/locales/zh-Hant.json
similarity index 100%
rename from invokeai/frontend/web/public/locales/zh_Hant.json
rename to invokeai/frontend/web/public/locales/zh-Hant.json
From 757ecdbf8236b19827335c492a95d8d8bf698864 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 11 Jul 2025 14:21:26 +1000
Subject: [PATCH 07/17] build(ui): downgrade idb-keyval
We have increased error rates after updating this package. Let's try
downgrading to see if that fixes the issue.
---
invokeai/frontend/web/package.json | 2 +-
invokeai/frontend/web/pnpm-lock.yaml | 10 +++++-----
2 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json
index 896f232f67..0dce48b892 100644
--- a/invokeai/frontend/web/package.json
+++ b/invokeai/frontend/web/package.json
@@ -63,7 +63,7 @@
"framer-motion": "^11.10.0",
"i18next": "^25.2.1",
"i18next-http-backend": "^3.0.2",
- "idb-keyval": "^6.2.2",
+ "idb-keyval": "6.2.1",
"jsondiffpatch": "^0.7.3",
"konva": "^9.3.20",
"linkify-react": "^4.3.1",
diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml
index a8ce2b263f..916ee8fb2d 100644
--- a/invokeai/frontend/web/pnpm-lock.yaml
+++ b/invokeai/frontend/web/pnpm-lock.yaml
@@ -81,8 +81,8 @@ importers:
specifier: ^3.0.2
version: 3.0.2
idb-keyval:
- specifier: ^6.2.2
- version: 6.2.2
+ specifier: 6.2.1
+ version: 6.2.1
jsondiffpatch:
specifier: ^0.7.3
version: 0.7.3
@@ -2927,8 +2927,8 @@ packages:
typescript:
optional: true
- idb-keyval@6.2.2:
- resolution: {integrity: sha512-yjD9nARJ/jb1g+CvD0tlhUHOrJ9Sy0P8T9MF3YaLlHnSRpwPfpTX0XIvpmw3gAJUmEu3FiICLBDPXVwyEvrleg==}
+ idb-keyval@6.2.1:
+ resolution: {integrity: sha512-8Sb3veuYCyrZL+VBt9LJfZjLUPWVvqn8tG28VqYNFCo43KHcKuq+b4EiXGeuaLAQWL2YmyDgMp2aSpH9JHsEQg==}
ieee754@1.2.1:
resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==}
@@ -7720,7 +7720,7 @@ snapshots:
optionalDependencies:
typescript: 5.8.3
- idb-keyval@6.2.2: {}
+ idb-keyval@6.2.1: {}
ieee754@1.2.1: {}
From e62d3f01a861265f368f51251b0fa54cf8f00ab2 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 11 Jul 2025 18:29:20 +1000
Subject: [PATCH 08/17] feat(app): better error message for failed model probe
- Old: No valid config found
- New: Unable to determine model type
---
invokeai/backend/model_manager/config.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 7521f2c512..c1e4d4435b 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -187,7 +187,7 @@ class ModelConfigBase(ABC, BaseModel):
else:
return config_cls.from_model_on_disk(mod, **overrides)
- raise InvalidModelConfigException("No valid config found")
+ raise InvalidModelConfigException("Unable to determine model type")
@classmethod
def get_tag(cls) -> Tag:
From bb3e5d16d85278800124fd318ea25895dfa0119d Mon Sep 17 00:00:00 2001
From: Kevin Turner <566360-keturn@users.noreply.gitlab.com>
Date: Fri, 11 Jul 2025 13:06:55 -0700
Subject: [PATCH 09/17] feat(Model Manager): refuse to download a file when
there's insufficient space
---
invokeai/app/services/download/download_default.py | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py
index 150c89bb10..0db9075dde 100644
--- a/invokeai/app/services/download/download_default.py
+++ b/invokeai/app/services/download/download_default.py
@@ -8,6 +8,7 @@ import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
+from shutil import disk_usage
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
import requests
@@ -335,6 +336,14 @@ class DownloadQueueService(DownloadQueueServiceBase):
assert job.download_path
+ free_space = disk_usage(job.download_path.parent).free
+ GB = 2**30
+ self._logger.debug(f"Download is {job.total_bytes / GB:.2f} GB of {free_space / GB:.2f} GB free.")
+ if free_space < job.total_bytes:
+ raise RuntimeError(
+ f"Free disk space {free_space / GB:.2f} GB is not enough for download of {job.total_bytes / GB:.2f} GB."
+ )
+
# Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
# for code that instead resumes an interrupted download.
if job.download_path.exists():
From d4e903ee2d1b33577678db7b449d5c056bd75fc7 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Fri, 11 Jul 2025 23:36:13 +1000
Subject: [PATCH 10/17] chore: bump version to v6.0.1
---
invokeai/version/invokeai_version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py
index 5a9f9c3295..79a961b4f6 100644
--- a/invokeai/version/invokeai_version.py
+++ b/invokeai/version/invokeai_version.py
@@ -1 +1 @@
-__version__ = "6.0.1rc1"
+__version__ = "6.0.1"
From d9a1efbabf7a272d94b165fde4831aff4a3c18f6 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 12 Jul 2025 00:05:10 +1000
Subject: [PATCH 11/17] fix(ui): staging area images may be slightly too large
---
.../components/SimpleSession/QueueItemPreviewMini.tsx | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/QueueItemPreviewMini.tsx b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/QueueItemPreviewMini.tsx
index 6ea9a52eb0..ceac8d3a41 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/QueueItemPreviewMini.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/QueueItemPreviewMini.tsx
@@ -77,7 +77,7 @@ export const QueueItemPreviewMini = memo(({ item, isSelected, index }: Props) =>
onDoubleClick={onDoubleClick}
>
- {imageDTO && }
+ {imageDTO && }
{!imageLoaded && }
From b23bff1b53efbcdcc167828e27595c23d39c6565 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 12 Jul 2025 00:05:18 +1000
Subject: [PATCH 12/17] fix(ui): center staging area images
---
.../components/SimpleSession/StagingAreaItemsList.tsx | 1 +
1 file changed, 1 insertion(+)
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/StagingAreaItemsList.tsx b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/StagingAreaItemsList.tsx
index c644608a2e..60a8458871 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/StagingAreaItemsList.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/SimpleSession/StagingAreaItemsList.tsx
@@ -94,6 +94,7 @@ const useScrollableStagingArea = (rootRef: RefObject) => {
const { viewport } = osInstance.elements();
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
+ viewport.style.textAlign = 'center';
},
},
options: {
From 97439c1daaed44ecc2d3c04851aa206ce858db11 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 12 Jul 2025 00:05:35 +1000
Subject: [PATCH 13/17] fix(ui): native context menu shown on right click on
short fat images
Closes #8254
---
.../features/gallery/components/ImageGrid/GalleryImage.tsx | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
index e2b604fa7a..c6647c6ef3 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
@@ -143,7 +143,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
DndDragPreviewSingleImageState | DndDragPreviewMultipleImageState | null
>(null);
// Must use callback ref - else chakra's Image fallback prop will break the ref & dnd
- const [element, ref] = useState(null);
+ const [element, ref] = useState(null);
const selectIsSelectedForCompare = useMemo(
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare === imageDTO.image_name),
[imageDTO.image_name]
@@ -246,6 +246,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
<>
{
data-selected-for-compare={isSelectedForCompare}
>
}
From 62f52c74a8ebcda6ba477eb721e1bbd9f958ae58 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 12 Jul 2025 00:15:12 +1000
Subject: [PATCH 14/17] fix(ui): linked negative style prompt not passed in
Closes #8256
---
.../src/features/nodes/util/graph/generation/buildSDXLGraph.ts | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
index 5f5c6e3b7f..638f2f8566 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
@@ -78,7 +78,7 @@ export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise
Date: Sat, 12 Jul 2025 14:36:02 +1000
Subject: [PATCH 15/17] fix(ui): gallery dnd
---
.../components/ImageGrid/GalleryImage.tsx | 134 +++++++++---------
1 file changed, 66 insertions(+), 68 deletions(-)
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
index c6647c6ef3..41d821f4d7 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
@@ -23,7 +23,7 @@ import { imageToCompareChanged, selectGallerySlice, selectionChanged } from 'fea
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
import type { MouseEvent, MouseEventHandler } from 'react';
-import { memo, useCallback, useEffect, useMemo, useState } from 'react';
+import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiImageBold } from 'react-icons/pi';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
@@ -42,45 +42,42 @@ const galleryImageContainerSX = {
'&[data-is-dragging=true]': {
opacity: 0.3,
},
- [`.${GALLERY_IMAGE_CLASS}`]: {
- touchAction: 'none',
- userSelect: 'none',
- webkitUserSelect: 'none',
- position: 'relative',
- justifyContent: 'center',
- alignItems: 'center',
- aspectRatio: '1/1',
- '::before': {
- content: '""',
- display: 'inline-block',
- position: 'absolute',
- top: 0,
- left: 0,
- right: 0,
- bottom: 0,
- pointerEvents: 'none',
- borderRadius: 'base',
- },
- '&[data-selected=true]::before': {
- boxShadow:
- 'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
- },
- '&[data-selected-for-compare=true]::before': {
- boxShadow:
- 'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
- },
- '&:hover::before': {
- boxShadow:
- 'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
- },
- '&:hover[data-selected=true]::before': {
- boxShadow:
- 'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
- },
- '&:hover[data-selected-for-compare=true]::before': {
- boxShadow:
- 'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
- },
+ userSelect: 'none',
+ webkitUserSelect: 'none',
+ position: 'relative',
+ justifyContent: 'center',
+ alignItems: 'center',
+ aspectRatio: '1/1',
+ '::before': {
+ content: '""',
+ display: 'inline-block',
+ position: 'absolute',
+ top: 0,
+ left: 0,
+ right: 0,
+ bottom: 0,
+ pointerEvents: 'none',
+ borderRadius: 'base',
+ },
+ '&[data-selected=true]::before': {
+ boxShadow:
+ 'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
+ },
+ '&[data-selected-for-compare=true]::before': {
+ boxShadow:
+ 'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
+ },
+ '&:hover::before': {
+ boxShadow:
+ 'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
+ },
+ '&:hover[data-selected=true]::before': {
+ boxShadow:
+ 'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
+ },
+ '&:hover[data-selected-for-compare=true]::before': {
+ boxShadow:
+ 'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
} satisfies SystemStyleObject;
@@ -142,8 +139,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
const [dragPreviewState, setDragPreviewState] = useState<
DndDragPreviewSingleImageState | DndDragPreviewMultipleImageState | null
>(null);
- // Must use callback ref - else chakra's Image fallback prop will break the ref & dnd
- const [element, ref] = useState(null);
+ const ref = useRef(null);
const selectIsSelectedForCompare = useMemo(
() => createSelector(selectGallerySlice, (gallery) => gallery.imageToCompare === imageDTO.image_name),
[imageDTO.image_name]
@@ -156,6 +152,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
const isSelected = useAppSelector(selectIsSelected);
useEffect(() => {
+ const element = ref.current;
if (!element) {
return;
}
@@ -221,7 +218,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
},
})
);
- }, [element, imageDTO, store]);
+ }, [imageDTO, store]);
const [isHovered, setIsHovered] = useState(false);
@@ -240,34 +237,35 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
navigationApi.focusPanelInActiveTab(VIEWER_PANEL_ID);
}, [store]);
- useImageContextMenu(imageDTO, element);
+ useImageContextMenu(imageDTO, ref);
return (
<>
-
-
- }
- objectFit="contain"
- maxW="full"
- maxH="full"
- borderRadius="base"
- />
-
-
-
+
+ }
+ objectFit="contain"
+ maxW="full"
+ maxH="full"
+ borderRadius="base"
+ />
+
+
{dragPreviewState?.type === 'multiple-image' ? createMultipleImageDragPreview(dragPreviewState) : null}
{dragPreviewState?.type === 'single-image' ? createSingleImageDragPreview(dragPreviewState) : null}
>
From 192b00d9692168d330f5ca6afb25d733596cb33c Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 12 Jul 2025 14:36:29 +1000
Subject: [PATCH 16/17] chore: bump version to v6.0.2
---
invokeai/version/invokeai_version.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/invokeai/version/invokeai_version.py b/invokeai/version/invokeai_version.py
index 79a961b4f6..7f229cf189 100644
--- a/invokeai/version/invokeai_version.py
+++ b/invokeai/version/invokeai_version.py
@@ -1 +1 @@
-__version__ = "6.0.1"
+__version__ = "6.0.2"
From 82fb897b6264f8f9436083023b00e0fa65cef287 Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sat, 12 Jul 2025 14:52:24 +1000
Subject: [PATCH 17/17] chore(ui): lint
---
.../features/gallery/components/ImageGrid/GalleryImage.tsx | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
index 41d821f4d7..739ee9c2ed 100644
--- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
+++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx
@@ -1,7 +1,7 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { draggable, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
-import { Box, Flex, Icon, Image } from '@invoke-ai/ui-library';
+import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import type { AppDispatch, AppGetState } from 'app/store/store';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
@@ -28,8 +28,6 @@ import { PiImageBold } from 'react-icons/pi';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
-const GALLERY_IMAGE_CLASS = 'gallery-image';
-
const galleryImageContainerSX = {
containerType: 'inline-size',
w: 'full',
@@ -255,7 +253,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
data-selected-for-compare={isSelectedForCompare}
>
}