diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index a2ac6b45c8..d224242a56 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -1,7 +1,7 @@ import io import json import traceback -from typing import ClassVar, Optional +from typing import ClassVar, Literal, Optional from fastapi import BackgroundTasks, Body, HTTPException, Path, Query, Request, Response, UploadFile from fastapi.responses import FileResponse @@ -562,3 +562,63 @@ async def get_bulk_download_item( return response except Exception: raise HTTPException(status_code=404) + + +@images_router.get("/collections/counts", operation_id="get_image_collection_counts") +async def get_image_collection_counts( + image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to count."), + categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."), + is_intermediate: Optional[bool] = Query(default=None, description="Whether to include intermediate images."), + board_id: Optional[str] = Query( + default=None, + description="The board id to filter by. Use 'none' to find images without a board.", + ), + search_term: Optional[str] = Query(default=None, description="The term to search for"), +) -> dict[str, int]: + """Gets counts for starred and unstarred image collections""" + + try: + counts = ApiDependencies.invoker.services.images.get_collection_counts( + image_origin=image_origin, + categories=categories, + is_intermediate=is_intermediate, + board_id=board_id, + search_term=search_term, + ) + return counts + except Exception: + raise HTTPException(status_code=500, detail="Failed to get collection counts") + + +@images_router.get("/collections/{collection}", operation_id="get_image_collection") +async def get_image_collection( + collection: Literal["starred", "unstarred"] = Path(..., description="The collection to retrieve from"), + image_origin: Optional[ResourceOrigin] = Query(default=None, description="The origin of images to list."), + categories: Optional[list[ImageCategory]] = Query(default=None, description="The categories of image to include."), + is_intermediate: Optional[bool] = Query(default=None, description="Whether to list intermediate images."), + board_id: Optional[str] = Query( + default=None, + description="The board id to filter by. Use 'none' to find images without a board.", + ), + offset: int = Query(default=0, description="The offset within the collection"), + limit: int = Query(default=50, description="The number of images to return"), + order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), + search_term: Optional[str] = Query(default=None, description="The term to search for"), +) -> OffsetPaginatedResults[ImageDTO]: + """Gets images from a specific collection (starred or unstarred)""" + + try: + image_dtos = ApiDependencies.invoker.services.images.get_collection_images( + collection=collection, + offset=offset, + limit=limit, + order_dir=order_dir, + image_origin=image_origin, + categories=categories, + is_intermediate=is_intermediate, + board_id=board_id, + search_term=search_term, + ) + return image_dtos + except Exception: + raise HTTPException(status_code=500, detail="Failed to get collection images") diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 1211c9762c..de42fa419d 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from datetime import datetime -from typing import Optional +from typing import Literal, Optional from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( @@ -97,3 +97,31 @@ class ImageRecordStorageBase(ABC): def get_most_recent_image_for_board(self, board_id: str) -> Optional[ImageRecord]: """Gets the most recent image for a board.""" pass + + @abstractmethod + def get_collection_counts( + self, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> dict[str, int]: + """Gets counts for starred and unstarred image collections.""" + pass + + @abstractmethod + def get_collection_images( + self, + collection: Literal["starred", "unstarred"], + offset: int = 0, + limit: int = 10, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> OffsetPaginatedResults[ImageRecord]: + """Gets images from a specific collection (starred or unstarred).""" + pass diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 23674e14e6..592004b793 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -1,6 +1,6 @@ import sqlite3 from datetime import datetime -from typing import Optional, Union, cast +from typing import Literal, Optional, Union, cast from invokeai.app.invocations.fields import MetadataField, MetadataFieldValidator from invokeai.app.services.image_records.image_records_base import ImageRecordStorageBase @@ -386,3 +386,181 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): return None return deserialize_image_record(dict(result)) + + def get_collection_counts( + self, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> dict[str, int]: + cursor = self._conn.cursor() + + # Build the base query conditions (same as get_many) + base_query = """--sql + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1 + """ + + query_conditions = "" + query_params: list[Union[int, str, bool]] = [] + + if image_origin is not None: + query_conditions += """--sql + AND images.image_origin = ? + """ + query_params.append(image_origin.value) + + if categories is not None: + category_strings = [c.value for c in set(categories)] + placeholders = ",".join("?" * len(category_strings)) + query_conditions += f"""--sql + AND images.image_category IN ( {placeholders} ) + """ + for c in category_strings: + query_params.append(c) + + if is_intermediate is not None: + query_conditions += """--sql + AND images.is_intermediate = ? + """ + query_params.append(is_intermediate) + + if board_id == "none": + query_conditions += """--sql + AND board_images.board_id IS NULL + """ + elif board_id is not None: + query_conditions += """--sql + AND board_images.board_id = ? + """ + query_params.append(board_id) + + if search_term: + query_conditions += """--sql + AND ( + images.metadata LIKE ? + OR images.created_at LIKE ? + ) + """ + query_params.append(f"%{search_term.lower()}%") + query_params.append(f"%{search_term.lower()}%") + + # Get starred count + starred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = TRUE;" + cursor.execute(starred_query, query_params) + starred_count = cast(int, cursor.fetchone()[0]) + + # Get unstarred count + unstarred_query = f"SELECT COUNT(*) {base_query} {query_conditions} AND images.starred = FALSE;" + cursor.execute(unstarred_query, query_params) + unstarred_count = cast(int, cursor.fetchone()[0]) + + return { + "starred_count": starred_count, + "unstarred_count": unstarred_count, + "total_count": starred_count + unstarred_count, + } + + def get_collection_images( + self, + collection: Literal["starred", "unstarred"], + offset: int = 0, + limit: int = 10, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> OffsetPaginatedResults[ImageRecord]: + cursor = self._conn.cursor() + + # Base queries + count_query = """--sql + SELECT COUNT(*) + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1 + """ + + images_query = f"""--sql + SELECT {IMAGE_DTO_COLS} + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1 + """ + + query_conditions = "" + query_params: list[Union[int, str, bool]] = [] + + # Add starred/unstarred filter + is_starred = collection == "starred" + query_conditions += """--sql + AND images.starred = ? + """ + query_params.append(is_starred) + + if image_origin is not None: + query_conditions += """--sql + AND images.image_origin = ? + """ + query_params.append(image_origin.value) + + if categories is not None: + category_strings = [c.value for c in set(categories)] + placeholders = ",".join("?" * len(category_strings)) + query_conditions += f"""--sql + AND images.image_category IN ( {placeholders} ) + """ + for c in category_strings: + query_params.append(c) + + if is_intermediate is not None: + query_conditions += """--sql + AND images.is_intermediate = ? + """ + query_params.append(is_intermediate) + + if board_id == "none": + query_conditions += """--sql + AND board_images.board_id IS NULL + """ + elif board_id is not None: + query_conditions += """--sql + AND board_images.board_id = ? + """ + query_params.append(board_id) + + if search_term: + query_conditions += """--sql + AND ( + images.metadata LIKE ? + OR images.created_at LIKE ? + ) + """ + query_params.append(f"%{search_term.lower()}%") + query_params.append(f"%{search_term.lower()}%") + + # Add ordering and pagination + query_pagination = f"""--sql + ORDER BY images.created_at {order_dir.value} LIMIT ? OFFSET ? + """ + + # Execute images query + images_query += query_conditions + query_pagination + ";" + images_params = query_params.copy() + images_params.extend([limit, offset]) + + cursor.execute(images_query, images_params) + result = cast(list[sqlite3.Row], cursor.fetchall()) + images = [deserialize_image_record(dict(r)) for r in result] + + # Execute count query + count_query += query_conditions + ";" + cursor.execute(count_query, query_params) + count = cast(int, cursor.fetchone()[0]) + + return OffsetPaginatedResults(items=images, offset=offset, limit=limit, total=count) diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index 5328c1854e..dd998e2578 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Optional +from typing import Callable, Literal, Optional from PIL.Image import Image as PILImageType @@ -147,3 +147,31 @@ class ImageServiceABC(ABC): def delete_images_on_board(self, board_id: str): """Deletes all images on a board.""" pass + + @abstractmethod + def get_collection_counts( + self, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> dict[str, int]: + """Gets counts for starred and unstarred image collections.""" + pass + + @abstractmethod + def get_collection_images( + self, + collection: Literal["starred", "unstarred"], + offset: int = 0, + limit: int = 10, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> OffsetPaginatedResults[ImageDTO]: + """Gets images from a specific collection (starred or unstarred).""" + pass diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index 1489a7ce45..83809ad3f4 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from PIL.Image import Image as PILImageType @@ -309,3 +309,68 @@ class ImageService(ImageServiceABC): except Exception as e: self.__invoker.services.logger.error("Problem getting intermediates count") raise e + + def get_collection_counts( + self, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> dict[str, int]: + try: + return self.__invoker.services.image_records.get_collection_counts( + image_origin=image_origin, + categories=categories, + is_intermediate=is_intermediate, + board_id=board_id, + search_term=search_term, + ) + except Exception as e: + self.__invoker.services.logger.error("Problem getting collection counts") + raise e + + def get_collection_images( + self, + collection: Literal["starred", "unstarred"], + offset: int = 0, + limit: int = 10, + order_dir: SQLiteDirection = SQLiteDirection.Descending, + image_origin: Optional[ResourceOrigin] = None, + categories: Optional[list[ImageCategory]] = None, + is_intermediate: Optional[bool] = None, + board_id: Optional[str] = None, + search_term: Optional[str] = None, + ) -> OffsetPaginatedResults[ImageDTO]: + try: + results = self.__invoker.services.image_records.get_collection_images( + collection=collection, + offset=offset, + limit=limit, + order_dir=order_dir, + image_origin=image_origin, + categories=categories, + is_intermediate=is_intermediate, + board_id=board_id, + search_term=search_term, + ) + + image_dtos = [ + image_record_to_dto( + image_record=r, + image_url=self.__invoker.services.urls.get_image_url(r.image_name), + thumbnail_url=self.__invoker.services.urls.get_image_url(r.image_name, True), + board_id=self.__invoker.services.board_image_records.get_board_for_image(r.image_name), + ) + for r in results.items + ] + + return OffsetPaginatedResults[ImageDTO]( + items=image_dtos, + offset=results.offset, + limit=results.limit, + total=results.total, + ) + except Exception as e: + self.__invoker.services.logger.error("Problem getting collection images") + raise e diff --git a/invokeai/frontend/web/.eslintrc.js b/invokeai/frontend/web/.eslintrc.js index 3e6498af4c..f4af658ea2 100644 --- a/invokeai/frontend/web/.eslintrc.js +++ b/invokeai/frontend/web/.eslintrc.js @@ -12,11 +12,13 @@ module.exports = { // TODO: ENABLE THIS RULE BEFORE v6.0.0 // 'i18next/no-literal-string': 'error', // https://eslint.org/docs/latest/rules/no-console - 'no-console': 'error', + 'no-console': 'warn', // https://eslint.org/docs/latest/rules/no-promise-executor-return 'no-promise-executor-return': 'error', // https://eslint.org/docs/latest/rules/require-await 'require-await': 'error', + // TODO: ENABLE THIS RULE BEFORE v6.0.0 + 'react/display-name': 'off', 'no-restricted-properties': [ 'error', { diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 9397144751..ec757494f5 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -39,7 +39,6 @@ import { authToastMiddleware } from 'services/api/authToastMiddleware'; import type { JsonObject } from 'type-fest'; import { STORAGE_PREFIX } from './constants'; -import { getDebugLoggerMiddleware } from './middleware/debugLoggerMiddleware'; import { actionSanitizer } from './middleware/devtools/actionSanitizer'; import { actionsDenylist } from './middleware/devtools/actionsDenylist'; import { stateSanitizer } from './middleware/devtools/stateSanitizer'; @@ -177,7 +176,7 @@ export const createStore = (uniqueStoreKey?: string, persist = true) => .concat(api.middleware) .concat(dynamicMiddlewares) .concat(authToastMiddleware) - .concat(getDebugLoggerMiddleware()) + // .concat(getDebugLoggerMiddleware()) .prepend(listenerMiddleware.middleware), enhancers: (getDefaultEnhancers) => { const _enhancers = getDefaultEnhancers().concat(autoBatchEnhancer()); diff --git a/invokeai/frontend/web/src/features/gallery/components/Gallery.tsx b/invokeai/frontend/web/src/features/gallery/components/Gallery.tsx index 1679cc1fb1..ec8f15f72e 100644 --- a/invokeai/frontend/web/src/features/gallery/components/Gallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/Gallery.tsx @@ -25,9 +25,8 @@ import { useBoardName } from 'services/api/hooks/useBoardName'; import { GallerySettingsPopover } from './GallerySettingsPopover/GallerySettingsPopover'; import { GalleryUploadButton } from './GalleryUploadButton'; -import GalleryImageGrid from './ImageGrid/GalleryImageGrid'; -import { GalleryPagination } from './ImageGrid/GalleryPagination'; import { GallerySearch } from './ImageGrid/GallerySearch'; +import { NewGallery } from './NewGallery'; const BASE_STYLES: ChakraProps['sx'] = { fontWeight: 'semibold', @@ -112,8 +111,9 @@ export const GalleryPanel = memo(() => { /> - - + {/* + */} + ); }); diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx new file mode 100644 index 0000000000..9f6807d278 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -0,0 +1,340 @@ +import { Box, Flex, forwardRef, Grid, GridItem, Image, Skeleton, Spinner, Text } from '@invoke-ai/ui-library'; +import { useAppSelector } from 'app/store/storeHooks'; +import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors'; +import React, { memo, useCallback, useMemo, useState } from 'react'; +import { VirtuosoGrid } from 'react-virtuoso'; +import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images'; +import type { ImageDTO } from 'services/api/types'; + +// Placeholder image component for now +const ImagePlaceholder = memo(({ image }: { image: ImageDTO }) => ( + +)); + +ImagePlaceholder.displayName = 'ImagePlaceholder'; + +// Loading skeleton component +const ImageSkeleton = memo(() => ); + +ImageSkeleton.displayName = 'ImageSkeleton'; + +// Hook to manage image data for virtual scrolling +const useVirtualImageData = () => { + const queryArgs = useAppSelector(selectImageCollectionQueryArgs); + + // Get total counts for position mapping + const { data: counts, isLoading: countsLoading } = useGetImageCollectionCountsQuery(queryArgs); + + // Cache for loaded image ranges + const [loadedRanges, setLoadedRanges] = useState>(new Map()); + + // Calculate position mappings + const positionInfo = useMemo(() => { + if (!counts) { + return null; + } + + const result = { + totalCount: counts.total_count, + starredCount: counts.starred_count ?? 0, + unstarredCount: counts.unstarred_count ?? 0, + starredEnd: (counts.starred_count ?? 0) - 1, + }; + + return result; + }, [counts]); + + // Clear cache when search parameters change + React.useEffect(() => { + setLoadedRanges(new Map()); + }, [queryArgs.board_id, queryArgs.search_term, queryArgs.categories]); + + // Return flag to indicate when search parameters have changed + const searchParamsChanged = useMemo(() => queryArgs, [queryArgs]); + + // Function to generate cache key for a range + const getRangeKey = useCallback((collection: 'starred' | 'unstarred', offset: number, limit: number) => { + return `${collection}-${offset}-${limit}`; + }, []); + + // Function to get images for a specific position range + const getImagesForRange = useCallback( + (startIndex: number, endIndex: number) => { + if (!positionInfo) { + return []; + } + + const requestedImages: (ImageDTO | null)[] = new Array(endIndex - startIndex + 1).fill(null); + const rangesToLoad: Array<{ + collection: 'starred' | 'unstarred'; + offset: number; + limit: number; + targetStartIndex: number; + }> = []; + + for (let i = startIndex; i <= endIndex; i++) { + const relativeIndex = i - startIndex; + + // Handle case where there are no starred images + if (positionInfo.starredCount === 0 || i >= positionInfo.starredCount) { + // This position is in the unstarred collection + const unstarredOffset = i - positionInfo.starredCount; + const rangeKey = getRangeKey('unstarred', Math.floor(unstarredOffset / 50) * 50, 50); + const cachedRange = loadedRanges.get(rangeKey); + + if (cachedRange) { + const imageIndex = unstarredOffset % 50; + if (imageIndex < cachedRange.length) { + requestedImages[relativeIndex] = cachedRange[imageIndex] ?? null; + } + } else { + // Need to load this range + const rangeOffset = Math.floor(unstarredOffset / 50) * 50; + rangesToLoad.push({ + collection: 'unstarred', + offset: rangeOffset, + limit: 50, + targetStartIndex: i, + }); + } + } else { + // This position is in the starred collection + const starredOffset = i; + const rangeKey = getRangeKey('starred', Math.floor(starredOffset / 50) * 50, 50); + const cachedRange = loadedRanges.get(rangeKey); + + if (cachedRange) { + const imageIndex = starredOffset % 50; + if (imageIndex < cachedRange.length) { + requestedImages[relativeIndex] = cachedRange[imageIndex] ?? null; + } + } else { + // Need to load this range + const rangeOffset = Math.floor(starredOffset / 50) * 50; + rangesToLoad.push({ + collection: 'starred', + offset: rangeOffset, + limit: 50, + targetStartIndex: i, + }); + } + } + } + + return { images: requestedImages, rangesToLoad }; + }, + [positionInfo, loadedRanges, getRangeKey] + ); + + return { + positionInfo, + countsLoading, + getImagesForRange, + setLoadedRanges, + loadedRanges, + searchParamsChanged, + }; +}; + +// Component to handle loading image ranges +const ImageRangeLoader = memo( + ({ + collection, + offset, + limit, + onDataLoaded, + }: { + collection: 'starred' | 'unstarred'; + offset: number; + limit: number; + onDataLoaded: (key: string, images: ImageDTO[]) => void; + }) => { + const queryArgs = useAppSelector(selectImageCollectionQueryArgs); + + const { data } = useGetImageCollectionQuery({ + collection, + offset, + limit, + ...queryArgs, + }); + + // Update cache when data is loaded - use useEffect to avoid state update during render + React.useEffect(() => { + if (data?.items) { + const key = `${collection}-${offset}-${limit}`; + onDataLoaded(key, data.items); + } + }, [data, collection, offset, limit, onDataLoaded]); + + return null; + } +); + +ImageRangeLoader.displayName = 'ImageRangeLoader'; + +export const NewGallery = memo(() => { + const { positionInfo, countsLoading, getImagesForRange, setLoadedRanges, searchParamsChanged } = + useVirtualImageData(); + const [activeRangeLoaders, setActiveRangeLoaders] = useState>(new Set()); + + // Force initial range loading when position info becomes available + const [hasInitiallyLoaded, setHasInitiallyLoaded] = useState(false); + + // Reset hasInitiallyLoaded when search parameters change + React.useEffect(() => { + setHasInitiallyLoaded(false); + setActiveRangeLoaders(new Set()); + }, [searchParamsChanged]); + + // Use useEffect for initial load to avoid state updates during render + React.useEffect(() => { + if (positionInfo && !hasInitiallyLoaded) { + // Force initial load of first 100 positions to ensure we see both starred and unstarred + const initialResult = getImagesForRange(0, Math.min(99, positionInfo.totalCount - 1)); + if (!Array.isArray(initialResult)) { + const { rangesToLoad } = initialResult; + rangesToLoad.forEach((rangeInfo) => { + const key = `${rangeInfo.collection}-${rangeInfo.offset}-${rangeInfo.limit}`; + if (!activeRangeLoaders.has(key)) { + setActiveRangeLoaders((prev) => new Set(prev).add(key)); + } + }); + } + setHasInitiallyLoaded(true); + } + }, [positionInfo, hasInitiallyLoaded, getImagesForRange, activeRangeLoaders]); + + // Handle range changes from virtuoso + const handleRangeChanged = useCallback( + (range: { startIndex: number; endIndex: number }) => { + if (!positionInfo) { + return; + } + + const result = getImagesForRange(range.startIndex, range.endIndex); + if (!Array.isArray(result)) { + const { rangesToLoad } = result; + + // Start loading any missing ranges + rangesToLoad.forEach((rangeInfo) => { + const key = `${rangeInfo.collection}-${rangeInfo.offset}-${rangeInfo.limit}`; + if (!activeRangeLoaders.has(key)) { + setActiveRangeLoaders((prev) => new Set(prev).add(key)); + } + }); + } + }, + [positionInfo, getImagesForRange, activeRangeLoaders] + ); + + // Handle when range data is loaded + const handleDataLoaded = useCallback( + (key: string, images: ImageDTO[]) => { + setLoadedRanges((prev) => new Map(prev).set(key, images)); + setActiveRangeLoaders((prev) => { + const next = new Set(prev); + next.delete(key); + return next; + }); + }, + [setLoadedRanges] + ); + + const computeItemKey = useCallback( + (index: number) => { + const result = getImagesForRange(index, index); + if (Array.isArray(result)) { + return `loading-${index}`; + } + const { images } = result; + const image = images[0]; + return image ? `image-${index}-${image.image_name}` : `skeleton-${index}`; + }, + [getImagesForRange] + ); + + // Render item at specific index + const itemContent = useCallback( + (index: number) => { + if (!positionInfo) { + return ; + } + + const result = getImagesForRange(index, index); + if (Array.isArray(result)) { + return ; + } + + const { images } = result; + const image = images[0]; + + if (image) { + return ; + } + + return ; + }, + [positionInfo, getImagesForRange] + ); + + if (countsLoading) { + return ( + + + Loading gallery... + + ); + } + + if (!positionInfo || positionInfo.totalCount === 0) { + return ( + + No images found + + ); + } + + return ( + + {/* Render active range loaders */} + {Array.from(activeRangeLoaders).map((key) => { + const [collection, offset, limit] = key.split('-'); + return ( + + ); + })} + + {/* Virtualized grid */} + + + ); +}); + +NewGallery.displayName = 'NewGallery'; + +const style = { height: '100%', width: '100%' }; + +const ListComponent = forwardRef((props, ref) => ( + +)); + +const ItemComponent = forwardRef((props, ref) => ); + +const components = { + Item: ItemComponent, + List: ListComponent, +}; diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts index a3f10d47e9..3e132a2d72 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts @@ -4,7 +4,7 @@ import { skipToken } from '@reduxjs/toolkit/query'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { selectGallerySlice } from 'features/gallery/store/gallerySlice'; import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types'; -import type { ListBoardsArgs, ListImagesArgs } from 'services/api/types'; +import type { ListBoardsArgs, ListImagesArgs, SQLiteDirection } from 'services/api/types'; export const selectFirstSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(0)); export const selectLastSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(-1)); @@ -38,6 +38,14 @@ export const selectListBoardsQueryArgs = createMemoizedSelector( export const selectAutoAddBoardId = createSelector(selectGallerySlice, (gallery) => gallery.autoAddBoardId); export const selectSelectedBoardId = createSelector(selectGallerySlice, (gallery) => gallery.selectedBoardId); + +export const selectImageCollectionQueryArgs = createMemoizedSelector(selectGallerySlice, (gallery) => ({ + board_id: gallery.selectedBoardId === 'none' ? undefined : gallery.selectedBoardId, + categories: gallery.galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES, + search_term: gallery.searchTerm || undefined, + order_dir: gallery.orderDir as SQLiteDirection, + is_intermediate: false, +})); export const selectAutoAssignBoardOnClick = createSelector( selectGallerySlice, (gallery) => gallery.autoAssignBoardOnClick diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index 0938d4a2a9..f3d867b7d3 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -77,7 +77,12 @@ export const imagesApi = api.injectEndpoints({ }), clearIntermediates: build.mutation({ query: () => ({ url: buildImagesUrl('intermediates'), method: 'DELETE' }), - invalidatesTags: ['IntermediatesCount', 'InvocationCacheStatus'], + invalidatesTags: [ + 'IntermediatesCount', + 'InvocationCacheStatus', + 'ImageCollectionCounts', + { type: 'ImageCollection', id: LIST_TAG }, + ], }), getImageDTO: build.query({ query: (image_name) => ({ url: buildImagesUrl(`i/${image_name}`) }), @@ -106,7 +111,11 @@ export const imagesApi = api.injectEndpoints({ // We ignore the deleted images when getting tags to invalidate. If we did not, we will invalidate the queries // that fetch image DTOs, metadata, and workflows. But we have just deleted those images! Invalidating the tags // will force those queries to re-fetch, and the requests will of course 404. - return getTagsToInvalidateForBoardAffectingMutation(result.affected_boards); + return [ + ...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards), + 'ImageCollectionCounts', + { type: 'ImageCollection', id: LIST_TAG }, + ]; }, }), deleteImages: build.mutation< @@ -125,7 +134,11 @@ export const imagesApi = api.injectEndpoints({ // We ignore the deleted images when getting tags to invalidate. If we did not, we will invalidate the queries // that fetch image DTOs, metadata, and workflows. But we have just deleted those images! Invalidating the tags // will force those queries to re-fetch, and the requests will of course 404. - return getTagsToInvalidateForBoardAffectingMutation(result.affected_boards); + return [ + ...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards), + 'ImageCollectionCounts', + { type: 'ImageCollection', id: LIST_TAG }, + ]; }, }), deleteUncategorizedImages: build.mutation< @@ -140,7 +153,11 @@ export const imagesApi = api.injectEndpoints({ // We ignore the deleted images when getting tags to invalidate. If we did not, we will invalidate the queries // that fetch image DTOs, metadata, and workflows. But we have just deleted those images! Invalidating the tags // will force those queries to re-fetch, and the requests will of course 404. - return getTagsToInvalidateForBoardAffectingMutation(result.affected_boards); + return [ + ...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards), + 'ImageCollectionCounts', + { type: 'ImageCollection', id: LIST_TAG }, + ]; }, }), /** @@ -184,6 +201,8 @@ export const imagesApi = api.injectEndpoints({ return [ ...getTagsToInvalidateForImageMutation(result.starred_images), ...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards), + 'ImageCollectionCounts', + { type: 'ImageCollection', id: LIST_TAG }, ]; }, }), @@ -206,6 +225,8 @@ export const imagesApi = api.injectEndpoints({ return [ ...getTagsToInvalidateForImageMutation(result.unstarred_images), ...getTagsToInvalidateForBoardAffectingMutation(result.affected_boards), + 'ImageCollectionCounts', + { type: 'ImageCollection', id: LIST_TAG }, ]; }, }), @@ -399,6 +420,52 @@ export const imagesApi = api.injectEndpoints({ }, }), }), + /** + * Get counts for starred and unstarred image collections + */ + getImageCollectionCounts: build.query< + paths['/api/v1/images/collections/counts']['get']['responses']['200']['content']['application/json'], + paths['/api/v1/images/collections/counts']['get']['parameters']['query'] + >({ + query: (queryArgs) => ({ + url: buildImagesUrl('collections/counts'), + method: 'GET', + params: queryArgs, + }), + providesTags: ['ImageCollectionCounts', 'FetchOnReconnect'], + }), + /** + * Get images from a specific collection (starred or unstarred) + */ + getImageCollection: build.query< + paths['/api/v1/images/collections/{collection}']['get']['responses']['200']['content']['application/json'], + paths['/api/v1/images/collections/{collection}']['get']['parameters']['path'] & + paths['/api/v1/images/collections/{collection}']['get']['parameters']['query'] + >({ + query: ({ collection, ...queryArgs }) => ({ + url: buildImagesUrl(`collections/${collection}`), + method: 'GET', + params: queryArgs, + }), + providesTags: (result, error, { collection, board_id, categories }) => { + const cacheKey = `${collection}-${board_id || 'all'}-${categories?.join(',') || 'all'}`; + return [{ type: 'ImageCollection', id: cacheKey }, 'FetchOnReconnect']; + }, + async onQueryStarted(_, { dispatch, queryFulfilled }) { + // Populate the getImageDTO cache with these images, similar to listImages + const res = await queryFulfilled; + const imageDTOs = res.data.items; + const updates: Param0 = []; + for (const imageDTO of imageDTOs) { + updates.push({ + endpointName: 'getImageDTO', + arg: imageDTO.image_name, + value: imageDTO, + }); + } + dispatch(imagesApi.util.upsertQueryEntries(updates)); + }, + }), }), }); @@ -420,6 +487,9 @@ export const { useStarImagesMutation, useUnstarImagesMutation, useBulkDownloadImagesMutation, + useGetImageCollectionCountsQuery, + useGetImageCollectionQuery, + useLazyGetImageCollectionQuery, } = imagesApi; /** diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index 123245fdfe..b2b65b4d99 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -23,6 +23,8 @@ const tagTypes = [ 'ImageList', 'ImageMetadata', 'ImageWorkflow', + 'ImageCollectionCounts', + 'ImageCollection', 'ImageMetadataFromFile', 'IntermediatesCount', 'SessionQueueItem', diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 827bc71e9e..a33e90312c 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -752,6 +752,46 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/images/collections/counts": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Image Collection Counts + * @description Gets counts for starred and unstarred image collections + */ + get: operations["get_image_collection_counts"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; + "/api/v1/images/collections/{collection}": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + /** + * Get Image Collection + * @description Gets images from a specific collection (starred or unstarred) + */ + get: operations["get_image_collection"]; + put?: never; + post?: never; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/boards/": { parameters: { query?: never; @@ -23675,6 +23715,97 @@ export interface operations { }; }; }; + get_image_collection_counts: { + parameters: { + query?: { + /** @description The origin of images to count. */ + image_origin?: components["schemas"]["ResourceOrigin"] | null; + /** @description The categories of image to include. */ + categories?: components["schemas"]["ImageCategory"][] | null; + /** @description Whether to include intermediate images. */ + is_intermediate?: boolean | null; + /** @description The board id to filter by. Use 'none' to find images without a board. */ + board_id?: string | null; + /** @description The term to search for */ + search_term?: string | null; + }; + header?: never; + path?: never; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": { + [key: string]: number; + }; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + get_image_collection: { + parameters: { + query?: { + /** @description The origin of images to list. */ + image_origin?: components["schemas"]["ResourceOrigin"] | null; + /** @description The categories of image to include. */ + categories?: components["schemas"]["ImageCategory"][] | null; + /** @description Whether to list intermediate images. */ + is_intermediate?: boolean | null; + /** @description The board id to filter by. Use 'none' to find images without a board. */ + board_id?: string | null; + /** @description The offset within the collection */ + offset?: number; + /** @description The number of images to return */ + limit?: number; + /** @description The order of sort */ + order_dir?: components["schemas"]["SQLiteDirection"]; + /** @description The term to search for */ + search_term?: string | null; + }; + header?: never; + path: { + /** @description The collection to retrieve from */ + collection: "starred" | "unstarred"; + }; + cookie?: never; + }; + requestBody?: never; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["OffsetPaginatedResults_ImageDTO_"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; list_boards: { parameters: { query?: {