diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index fce598f1f4..e3187822d3 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -622,3 +622,31 @@ async def get_image_collection( return image_dtos except Exception: raise HTTPException(status_code=500, detail="Failed to get collection images") + + +@images_router.get("/names", operation_id="get_image_names") +async def get_image_names( + image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."), + categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."), + is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."), + board_id: Optional[str] = Query( + default=None, + description="The board id to filter by. Use 'none' to find images without a board.", + ), + order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), + search_term: Optional[str] = Query(default=None, description="The term to search for"), +) -> list[str]: + """Gets ordered list of all image names (starred first, then unstarred)""" + + try: + image_names = ApiDependencies.invoker.services.images.get_image_names( + order_dir=order_dir, + image_origin=image_origin, + categories=categories, + is_intermediate=is_intermediate, + board_id=board_id, + search_term=search_term, + ) + return image_names + except Exception: + raise HTTPException(status_code=500, detail="Failed to get image names") diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 86fc95fb98..e640e3facb 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -126,3 +126,16 @@ class ImageRecordStorageBase(ABC): ) -> OffsetPaginatedResults[ImageRecord]: """Gets images from a specific collection (starred or unstarred).""" pass + + @abstractmethod + def get_image_names( + self, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> list[str]: + """Gets ordered list of all image names (starred first, then unstarred).""" + pass diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 1a6209eb4c..bf0706d426 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -561,3 +561,76 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): count = cast(int, cursor.fetchone()[0]) return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count) + + def get_image_names( + self, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> list[str]: + cursor = self._conn.cursor() + + # Base query to get image names in order (starred first, then unstarred) + query = """--sql + SELECT images.image_name + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1 + """ + + query_conditions = "" + query_params: list[Union[int, str, bool]] = [] + + if image_origin is not None: + query_conditions += """--sql + AND images.image_origin = ? + """ + query_params.append(image_origin.value) + + if categories is not None: + category_strings = [c.value for c in set(categories)] + placeholders = ",".join("?" * len(category_strings)) + query_conditions += f"""--sql + AND images.image_category IN ( {placeholders} ) + """ + for c in category_strings: + query_params.append(c) + + if is_intermediate is not None: + query_conditions += """--sql + AND images.is_intermediate = ? + """ + query_params.append(is_intermediate) + + if board_id == "none": + query_conditions += """--sql + AND board_images.board_id IS NULL + """ + elif board_id is not None: + query_conditions += """--sql + AND board_images.board_id = ? + """ + query_params.append(board_id) + + if search_term: + query_conditions += """--sql + AND ( + images.metadata LIKE ? + OR images.created_at LIKE ? + ) + """ + query_params.append(f"%{search_term.lower()}%") + query_params.append(f"%{search_term.lower()}%") + + # Order by starred first, then by created_at + query += query_conditions + f"""--sql + ORDER BY images.starred DESC, images.created_at {order_dir.value} + """ + + cursor.execute(query, query_params) + result = cast(list[sqlite3.Row], cursor.fetchall()) + + return [row[0] for row in result] diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index 4503f822c0..4886d31cca 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -176,3 +176,16 @@ class ImageServiceABC(ABC): ) -> OffsetPaginatedResults[ImageDTO]: """Gets images from a specific collection (starred or unstarred).""" pass + + @abstractmethod + def get_image_names( + self, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> list[str]: + """Gets ordered list of all image names (starred first, then unstarred).""" + pass diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index 6ab0f1dc8d..62a1262dc8 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -376,3 +376,25 @@ class ImageService(ImageServiceABC): except Exception as e: self.__invoker.services.logger.error("Problem getting collection images") raise e + + def get_image_names( + self, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> list[str]: + try: + return self.__invoker.services.image_records.get_image_names( + order_dir=order_dir, + image_origin=image_origin, + categories=categories, + is_intermediate=is_intermediate, + board_id=board_id, + search_term=search_term, + ) + except Exception as e: + self.__invoker.services.logger.error("Problem getting image names") + raise e diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts index b68920b96a..c41ef7c654 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts @@ -5,7 +5,7 @@ import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySe import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice'; import { uniq } from 'lodash-es'; import { imagesApi } from 'services/api/endpoints/images'; -import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types'; +import type { ImageCategory, SQLiteDirection } from 'services/api/types'; // Type for image collection query arguments type ImageCollectionQueryArgs = { @@ -17,53 +17,12 @@ type ImageCollectionQueryArgs = { }; /** - * Helper function to get all cached image data from collection queries - * Returns a combined array of starred images followed by unstarred images + * Helper function to get cached image names list for selection operations + * Returns an ordered array of image names (starred first, then unstarred) */ -const getCachedImageList = (state: RootState, queryArgs: ImageCollectionQueryArgs): ImageDTO[] => { - const countsQueryResult = imagesApi.endpoints.getImageCollectionCounts.select(queryArgs)(state); - - if (!countsQueryResult.data) { - return []; - } - - const { starred_count, unstarred_count } = countsQueryResult.data; - - const imageDTOs: ImageDTO[] = []; - - // Add starred images first (in order) - if (starred_count > 0) { - for (let offset = 0; offset < starred_count; offset += 50) { - const queryResult = imagesApi.endpoints.getImageCollection.select({ - collection: 'starred', - offset, - limit: 50, - ...queryArgs, - })(state); - - if (queryResult.data?.items) { - imageDTOs.push(...queryResult.data.items); - } - } - } - - // Add unstarred images (in order) - if (unstarred_count > 0) { - for (let offset = 0; offset < unstarred_count; offset += 50) { - const queryResult = imagesApi.endpoints.getImageCollection.select({ - collection: 'unstarred', - offset, - limit: 50, - ...queryArgs, - })(state); - - if (queryResult.data?.items) { - imageDTOs.push(...queryResult.data.items); - } - } - } - - return imageDTOs; +const getCachedImageNames = (state: RootState, queryArgs: ImageCollectionQueryArgs): string[] => { + const queryResult = imagesApi.endpoints.getImageNames.select(queryArgs)(state); + return queryResult.data || []; }; export const galleryImageClicked = createAction<{ @@ -93,12 +52,12 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen const state = getState(); const queryArgs = selectImageCollectionQueryArgs(state); - // Get all cached image data - const imageDTOs = getCachedImageList(state, queryArgs); + // Get cached image names for selection operations + const imageNames = getCachedImageNames(state, queryArgs); - // If we don't have the image data cached, we can't perform selection operations - // This can happen if the user clicks on an image before all data is loaded - if (imageDTOs.length === 0) { + // If we don't have the image names cached, we can't perform selection operations + // This can happen if the user clicks on an image before the names are loaded + if (imageNames.length === 0) { // For basic click without modifiers, we can still set selection if (!shiftKey && !ctrlKey && !metaKey && !altKey) { dispatch(selectionChanged([imageName])); @@ -117,13 +76,13 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen } else if (shiftKey) { const rangeEndImageName = imageName; const lastSelectedImage = selection.at(-1); - const lastClickedIndex = imageDTOs.findIndex((n) => n.image_name === lastSelectedImage); - const currentClickedIndex = imageDTOs.findIndex((n) => n.image_name === rangeEndImageName); + const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage); + const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName); if (lastClickedIndex > -1 && currentClickedIndex > -1) { // We have a valid range! const start = Math.min(lastClickedIndex, currentClickedIndex); const end = Math.max(lastClickedIndex, currentClickedIndex); - const imagesToSelect = imageDTOs.slice(start, end + 1).map(({ image_name }) => image_name); + const imagesToSelect = imageNames.slice(start, end + 1); dispatch(selectionChanged(uniq(selection.concat(imagesToSelect)))); } } else if (ctrlKey || metaKey) { diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx index 21162829ad..b4838ba142 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -14,7 +14,11 @@ import type { VirtuosoGridHandle, } from 'react-virtuoso'; import { VirtuosoGrid } from 'react-virtuoso'; -import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images'; +import { + useGetImageCollectionCountsQuery, + useGetImageCollectionQuery, + useGetImageNamesQuery, +} from 'services/api/endpoints/images'; import type { ImageCategory, SQLiteDirection } from 'services/api/types'; import { useDebounce } from 'use-debounce'; @@ -167,6 +171,10 @@ export const NewGallery = memo(() => { const { counts, isLoading } = useGetImageCollectionCountsQuery(queryArgs, getImageCollectionCountsOptions); + // Load image names for selection operations - this is lightweight and ensures + // selection operations work even before image data is fully loaded + useGetImageNamesQuery(queryArgs); + // Reset scroll position when query parameters change useEffect(() => { if (virtuosoRef.current && counts.total_count > 0) { diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index cb76d47865..edf8786e32 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -5,11 +5,13 @@ import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/type import type { components, paths } from 'services/api/schema'; import type { GraphAndWorkflowResponse, + ImageCategory, ImageDTO, ImageUploadEntryRequest, ImageUploadEntryResponse, ListImagesArgs, ListImagesResponse, + SQLiteDirection, UploadImageArg, } from 'services/api/types'; import { getCategories, getListImagesUrl } from 'services/api/util'; @@ -471,6 +473,26 @@ export const imagesApi = api.injectEndpoints({ dispatch(imagesApi.util.upsertQueryEntries(updates)); }, }), + /** + * Get ordered list of image names for selection operations + */ + getImageNames: build.query< + string[], + { + image_origin?: 'internal' | 'external' | null; + categories?: ImageCategory[] | null; + is_intermediate?: boolean | null; + board_id?: string | null; + search_term?: string | null; + order_dir?: SQLiteDirection; + } + >({ + query: (queryArgs) => ({ + url: buildImagesUrl('names', queryArgs), + method: 'GET', + }), + providesTags: ['ImageNameList', 'FetchOnReconnect'], + }), }), }); @@ -495,6 +517,7 @@ export const { useGetImageCollectionCountsQuery, useGetImageCollectionQuery, useLazyGetImageCollectionQuery, + useGetImageNamesQuery, } = imagesApi; /**