From ab5cb2c26464b2dd5074babfa55887fe3aefa0d6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 26 Jun 2025 16:44:51 +1000 Subject: [PATCH] refactor: optimistic gallery updates --- invokeai/app/api/routers/images.py | 37 +++- .../image_records/image_records_base.py | 5 +- .../image_records/image_records_common.py | 7 + .../image_records/image_records_sqlite.py | 52 ++++-- invokeai/app/services/images/images_base.py | 5 +- .../app/services/images/images_default.py | 3 +- .../middleware/listenerMiddleware/index.ts | 4 - .../listeners/galleryImageClicked.ts | 77 -------- .../components/ImageGrid/GalleryImage.tsx | 75 ++++++-- .../gallery/components/NewGallery.tsx | 171 +++++++++--------- .../components/use-gallery-image-names.ts | 6 +- .../hooks/useRangeBasedImageFetching.ts | 86 +++++++++ .../ui/layouts/auto-layout-context.tsx | 17 +- .../ui/layouts/canvas-tab-auto-layout.tsx | 2 +- .../ui/layouts/generate-tab-auto-layout.tsx | 2 +- .../ui/layouts/upscaling-tab-auto-layout.tsx | 2 +- .../ui/layouts/workflows-tab-auto-layout.tsx | 2 +- .../web/src/services/api/endpoints/images.ts | 36 +++- .../frontend/web/src/services/api/schema.ts | 86 ++++++++- .../frontend/web/src/services/api/types.ts | 2 + .../services/api/util/optimisticUpdates.ts | 90 +++++++++ .../services/events/onInvocationComplete.tsx | 141 ++++++--------- 22 files changed, 605 insertions(+), 303 deletions(-) delete mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts create mode 100644 invokeai/frontend/web/src/features/gallery/hooks/useRangeBasedImageFetching.ts create mode 100644 invokeai/frontend/web/src/services/api/util/optimisticUpdates.ts diff --git a/invokeai/app/api/routers/images.py b/invokeai/app/api/routers/images.py index a9bcc9f768..dfc4d63d19 100644 --- a/invokeai/app/api/routers/images.py +++ b/invokeai/app/api/routers/images.py @@ -14,6 +14,7 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, + ImageNamesResult, ImageRecordChanges, ResourceOrigin, ) @@ -576,11 +577,11 @@ async def get_image_names( order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"), starred_first: bool = Query(default=True, description="Whether to sort by starred images first"), search_term: Optional[str] = Query(default=None, description="The term to search for"), -) -> list[str]: - """Gets ordered list of all image names (starred first, then unstarred)""" +) -> ImageNamesResult: + """Gets ordered list of image names with metadata for optimistic updates""" try: - image_names = ApiDependencies.invoker.services.images.get_image_names( + result = ApiDependencies.invoker.services.images.get_image_names( starred_first=starred_first, order_dir=order_dir, image_origin=image_origin, @@ -589,6 +590,34 @@ async def get_image_names( board_id=board_id, search_term=search_term, ) - return image_names + return result except Exception: raise HTTPException(status_code=500, detail="Failed to get image names") + + +@images_router.post( + "/images_by_names", + operation_id="get_images_by_names", + responses={200: {"model": list[ImageDTO]}}, +) +async def get_images_by_names( + image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"), +) -> list[ImageDTO]: + """Gets image DTOs for the specified image names. Maintains order of input names.""" + + try: + image_service = ApiDependencies.invoker.services.images + + # Fetch DTOs preserving the order of requested names + image_dtos: list[ImageDTO] = [] + for name in image_names: + try: + dto = image_service.get_dto(name) + image_dtos.append(dto) + except Exception: + # Skip missing images - they may have been deleted between name fetch and DTO fetch + continue + + return image_dtos + except Exception: + raise HTTPException(status_code=500, detail="Failed to get image DTOs") diff --git a/invokeai/app/services/image_records/image_records_base.py b/invokeai/app/services/image_records/image_records_base.py index 128ced7b09..ff271e2394 100644 --- a/invokeai/app/services/image_records/image_records_base.py +++ b/invokeai/app/services/image_records/image_records_base.py @@ -5,6 +5,7 @@ from typing import Optional from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, + ImageNamesResult, ImageRecord, ImageRecordChanges, ResourceOrigin, @@ -108,6 +109,6 @@ class ImageRecordStorageBase(ABC): is_intermediate: Optional[bool] = None, board_id: Optional[str] = None, search_term: Optional[str] = None, - ) -> list[str]: - """Gets ordered list of all image names (starred first, then unstarred).""" + ) -> ImageNamesResult: + """Gets ordered list of image names with metadata for optimistic updates.""" pass diff --git a/invokeai/app/services/image_records/image_records_common.py b/invokeai/app/services/image_records/image_records_common.py index eee3c0cd9b..d91b0653ba 100644 --- a/invokeai/app/services/image_records/image_records_common.py +++ b/invokeai/app/services/image_records/image_records_common.py @@ -212,3 +212,10 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord: class ImageCollectionCounts(BaseModel): starred_count: int = Field(description="The number of starred images in the collection.") unstarred_count: int = Field(description="The number of unstarred images in the collection.") + + +class ImageNamesResult(BaseModel): + """Response containing ordered image names with metadata for optimistic updates.""" + image_names: list[str] = Field(description="Ordered list of image names") + starred_count: int = Field(description="Number of starred images (when starred_first=True)") + total_count: int = Field(description="Total number of images matching the query") diff --git a/invokeai/app/services/image_records/image_records_sqlite.py b/invokeai/app/services/image_records/image_records_sqlite.py index 3086880560..fe83c772b5 100644 --- a/invokeai/app/services/image_records/image_records_sqlite.py +++ b/invokeai/app/services/image_records/image_records_sqlite.py @@ -7,6 +7,7 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt from invokeai.app.services.image_records.image_records_common import ( IMAGE_DTO_COLS, ImageCategory, + ImageNamesResult, ImageRecord, ImageRecordChanges, ImageRecordDeleteException, @@ -396,17 +397,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): is_intermediate: Optional[bool] = None, board_id: Optional[str] = None, search_term: Optional[str] = None, - ) -> list[str]: + ) -> ImageNamesResult: cursor = self._conn.cursor() - # Base query to get image names in order (starred first, then unstarred) - query = """--sql - SELECT images.image_name - FROM images - LEFT JOIN board_images ON board_images.image_name = images.image_name - WHERE 1=1 - """ - + # Build query conditions (reused for both starred count and image names queries) query_conditions = "" query_params: list[Union[int, str, bool]] = [] @@ -451,22 +445,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase): query_params.append(f"%{search_term.lower()}%") query_params.append(f"%{search_term.lower()}%") + # Get starred count if starred_first is enabled + starred_count = 0 if starred_first: - query += ( - query_conditions - + f"""--sql + starred_count_query = f"""--sql + SELECT COUNT(*) + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE images.starred = TRUE AND (1=1{query_conditions}) + """ + cursor.execute(starred_count_query, query_params) + starred_count = cast(int, cursor.fetchone()[0]) + + # Get all image names with proper ordering + if starred_first: + names_query = f"""--sql + SELECT images.image_name + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1{query_conditions} ORDER BY images.starred DESC, images.created_at {order_dir.value} """ - ) else: - query += ( - query_conditions - + f"""--sql + names_query = f"""--sql + SELECT images.image_name + FROM images + LEFT JOIN board_images ON board_images.image_name = images.image_name + WHERE 1=1{query_conditions} ORDER BY images.created_at {order_dir.value} """ - ) - cursor.execute(query, query_params) + cursor.execute(names_query, query_params) result = cast(list[sqlite3.Row], cursor.fetchall()) + image_names = [row[0] for row in result] - return [row[0] for row in result] + return ImageNamesResult( + image_names=image_names, + starred_count=starred_count, + total_count=len(image_names) + ) diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index 3bf832cc71..e1fe02c1ec 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType from invokeai.app.invocations.fields import MetadataField from invokeai.app.services.image_records.image_records_common import ( ImageCategory, + ImageNamesResult, ImageRecord, ImageRecordChanges, ResourceOrigin, @@ -158,6 +159,6 @@ class ImageServiceABC(ABC): is_intermediate: Optional[bool] = None, board_id: Optional[str] = None, search_term: Optional[str] = None, - ) -> list[str]: - """Gets ordered list of all image names.""" + ) -> ImageNamesResult: + """Gets ordered list of image names with metadata for optimistic updates.""" pass diff --git a/invokeai/app/services/images/images_default.py b/invokeai/app/services/images/images_default.py index 4547d46c04..64ef0751b2 100644 --- a/invokeai/app/services/images/images_default.py +++ b/invokeai/app/services/images/images_default.py @@ -10,6 +10,7 @@ from invokeai.app.services.image_files.image_files_common import ( ) from invokeai.app.services.image_records.image_records_common import ( ImageCategory, + ImageNamesResult, ImageRecord, ImageRecordChanges, ImageRecordDeleteException, @@ -319,7 +320,7 @@ class ImageService(ImageServiceABC): is_intermediate: Optional[bool] = None, board_id: Optional[str] = None, search_term: Optional[str] = None, - ) -> list[str]: + ) -> ImageNamesResult: try: return self.__invoker.services.image_records.get_image_names( starred_first=starred_first, diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index 3629eb345f..8f0a41d53b 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -10,7 +10,6 @@ import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddlew import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload'; import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear'; import { addEnsureImageIsSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/ensureImageIsSelectedListener'; -import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked'; import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema'; import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard'; import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard'; @@ -44,9 +43,6 @@ addImageUploadedFulfilledListener(startAppListening); // Image deleted addDeleteBoardAndImagesFulfilledListener(startAppListening); -// Gallery -addGalleryImageClickedListener(startAppListening); - // User Invoked addEnqueueRequestedLinear(startAppListening); addEnqueueRequestedUpscale(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts deleted file mode 100644 index e937e0f710..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/galleryImageClicked.ts +++ /dev/null @@ -1,77 +0,0 @@ -import { createAction } from '@reduxjs/toolkit'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { uniq } from 'es-toolkit/compat'; -import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors'; -import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice'; -import { imagesApi } from 'services/api/endpoints/images'; - -export const galleryImageClicked = createAction<{ - imageName: string; - shiftKey: boolean; - ctrlKey: boolean; - metaKey: boolean; - altKey: boolean; -}>('gallery/imageClicked'); - -/** - * This listener handles the logic for selecting images in the gallery. - * - * Previously, this logic was in a `useCallback` with the whole gallery selection as a dependency. Every time - * the selection changed, the callback got recreated and all images rerendered. This could easily block for - * hundreds of ms, more for lower end devices. - * - * Moving this logic into a listener means we don't need to recalculate anything dynamically and the gallery - * is much more responsive. - */ - -export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: galleryImageClicked, - effect: (action, { dispatch, getState }) => { - const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload; - const state = getState(); - const queryArgs = selectListImageNamesQueryArgs(state); - const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data ?? []; - - // If we don't have the image names cached, we can't perform selection operations - // This can happen if the user clicks on an image before the names are loaded - if (imageNames.length === 0) { - // For basic click without modifiers, we can still set selection - if (!shiftKey && !ctrlKey && !metaKey && !altKey) { - dispatch(selectionChanged([imageName])); - } - return; - } - - const selection = state.gallery.selection; - - if (altKey) { - if (state.gallery.imageToCompare === imageName) { - dispatch(imageToCompareChanged(null)); - } else { - dispatch(imageToCompareChanged(imageName)); - } - } else if (shiftKey) { - const rangeEndImageName = imageName; - const lastSelectedImage = selection.at(-1); - const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage); - const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName); - if (lastClickedIndex > -1 && currentClickedIndex > -1) { - // We have a valid range! - const start = Math.min(lastClickedIndex, currentClickedIndex); - const end = Math.max(lastClickedIndex, currentClickedIndex); - const imagesToSelect = imageNames.slice(start, end + 1); - dispatch(selectionChanged(uniq(selection.concat(imagesToSelect)))); - } - } else if (ctrlKey || metaKey) { - if (selection.some((n) => n === imageName) && selection.length > 1) { - dispatch(selectionChanged(uniq(selection.filter((n) => n !== imageName)))); - } else { - dispatch(selectionChanged(uniq(selection.concat(imageName)))); - } - } else { - dispatch(selectionChanged([imageName])); - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx index 2067d75f49..34436fceda 100644 --- a/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/ImageGrid/GalleryImage.tsx @@ -3,9 +3,10 @@ import { draggable, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Box, Flex, Image } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; -import { galleryImageClicked } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked'; import { useAppStore } from 'app/store/nanostores/store'; +import type { AppDispatch, AppGetState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; +import { uniq } from 'es-toolkit'; import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd'; import type { DndDragPreviewMultipleImageState } from 'features/dnd/DndDragPreviewMultipleImage'; import { createMultipleImageDragPreview, setMultipleImageDragPreview } from 'features/dnd/DndDragPreviewMultipleImage'; @@ -15,11 +16,13 @@ import { firefoxDndFix } from 'features/dnd/util'; import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu'; import { GalleryImageHoverIcons } from 'features/gallery/components/ImageGrid/GalleryImageHoverIcons'; import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId'; -import { imageToCompareChanged, selectGallerySlice } from 'features/gallery/store/gallerySlice'; +import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors'; +import { imageToCompareChanged, selectGallerySlice, selectionChanged } from 'features/gallery/store/gallerySlice'; import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context'; import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared'; -import type { MouseEventHandler } from 'react'; +import type { MouseEvent, MouseEventHandler } from 'react'; import { memo, useCallback, useEffect, useId, useMemo, useRef, useState } from 'react'; +import { imagesApi } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; // This class name is used to calculate the number of images that fit in the gallery @@ -83,6 +86,54 @@ interface Props { imageDTO: ImageDTO; } +const buildOnClick = + (imageName: string, dispatch: AppDispatch, getState: AppGetState) => (e: MouseEvent) => { + const { shiftKey, ctrlKey, metaKey, altKey } = e; + const state = getState(); + const queryArgs = selectListImageNamesQueryArgs(state); + const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data?.image_names ?? []; + + // If we don't have the image names cached, we can't perform selection operations + // This can happen if the user clicks on an image before the names are loaded + if (imageNames.length === 0) { + // For basic click without modifiers, we can still set selection + if (!shiftKey && !ctrlKey && !metaKey && !altKey) { + dispatch(selectionChanged([imageName])); + } + return; + } + + const selection = state.gallery.selection; + + if (altKey) { + if (state.gallery.imageToCompare === imageName) { + dispatch(imageToCompareChanged(null)); + } else { + dispatch(imageToCompareChanged(imageName)); + } + } else if (shiftKey) { + const rangeEndImageName = imageName; + const lastSelectedImage = selection.at(-1); + const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage); + const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName); + if (lastClickedIndex > -1 && currentClickedIndex > -1) { + // We have a valid range! + const start = Math.min(lastClickedIndex, currentClickedIndex); + const end = Math.max(lastClickedIndex, currentClickedIndex); + const imagesToSelect = imageNames.slice(start, end + 1); + dispatch(selectionChanged(uniq(selection.concat(imagesToSelect)))); + } + } else if (ctrlKey || metaKey) { + if (selection.some((n) => n === imageName) && selection.length > 1) { + dispatch(selectionChanged(uniq(selection.filter((n) => n !== imageName)))); + } else { + dispatch(selectionChanged(uniq(selection.concat(imageName)))); + } + } else { + dispatch(selectionChanged([imageName])); + } + }; + export const GalleryImage = memo(({ imageDTO }: Props) => { const store = useAppStore(); const autoLayoutContext = useAutoLayoutContext(); @@ -192,20 +243,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => { setIsHovered(false); }, []); - const onClick = useCallback>( - (e) => { - store.dispatch( - galleryImageClicked({ - imageName: imageDTO.image_name, - shiftKey: e.shiftKey, - ctrlKey: e.ctrlKey, - metaKey: e.metaKey, - altKey: e.altKey, - }) - ); - }, - [imageDTO, store] - ); + const onClick = useMemo(() => buildOnClick(imageDTO.image_name, store.dispatch, store.getState), [imageDTO, store]); const onDoubleClick = useCallback>(() => { store.dispatch(imageToCompareChanged(null)); @@ -238,6 +276,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => { ref={ref} src={imageDTO.thumbnail_url} w={imageDTO.width} + fallback={} objectFit="contain" maxW="full" maxH="full" @@ -253,3 +292,5 @@ export const GalleryImage = memo(({ imageDTO }: Props) => { }); GalleryImage.displayName = 'GalleryImage'; + +export const GalleryImagePlaceholder = memo(() => ); diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx index 8bc48a8fce..c8f3e334a0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -2,15 +2,16 @@ import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from ' import { createSelector } from '@reduxjs/toolkit'; import { logger } from 'app/logging/logger'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; +import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBasedImageFetching'; import type { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors'; import { - LIMIT, selectGalleryImageMinimumWidth, selectImageToCompare, selectLastSelectedImage, } from 'features/gallery/store/gallerySelectors'; import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; +import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context'; import { useOverlayScrollbars } from 'overlayscrollbars-react'; import type { MutableRefObject, RefObject } from 'react'; import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; @@ -23,22 +24,15 @@ import type { VirtuosoGridHandle, } from 'react-virtuoso'; import { VirtuosoGrid } from 'react-virtuoso'; -import { useListImagesQuery } from 'services/api/endpoints/images'; -import type { ImageDTO } from 'services/api/types'; +import { imagesApi } from 'services/api/endpoints/images'; import { useDebounce } from 'use-debounce'; -import { GalleryImage } from './ImageGrid/GalleryImage'; +import { GalleryImage, GalleryImagePlaceholder } from './ImageGrid/GalleryImage'; import { GallerySelectionCountTag } from './ImageGrid/GallerySelectionCountTag'; import { useGalleryImageNames } from './use-gallery-image-names'; const log = logger('gallery'); -// Constants -const VIEWPORT_BUFFER = 2048; -const SCROLL_SEEK_VELOCITY_THRESHOLD = 4096; -const DEBOUNCE_DELAY = 500; -const SPINNER_OPACITY = 0.3; - type ListImageNamesQueryArgs = ReturnType; type GridContext = { @@ -46,58 +40,41 @@ type GridContext = { imageNames: string[]; }; -// Hook to get an image DTO from cache or trigger loading -const useImageDTOFromListQuery = ( - index: number, - imageName: string, - queryArgs: ListImageNamesQueryArgs -): ImageDTO | null => { - const { arg, options } = useMemo(() => { - const pageOffset = Math.floor(index / LIMIT) * LIMIT; - return { - arg: { - ...queryArgs, - offset: pageOffset, - limit: LIMIT, - } satisfies Parameters[0], - options: { - selectFromResult: ({ data }) => { - const imageDTO = data?.items?.[index - pageOffset] || null; - if (imageDTO && imageDTO.image_name !== imageName) { - log.warn(`Image at index ${index} does not match expected image name ${imageName}`); - return { imageDTO: null }; - } - return { imageDTO }; - }, - } satisfies Parameters[1], - }; - }, [index, queryArgs, imageName]); +const ImageAtPosition = memo(({ imageName }: { index: number; imageName: string }) => { + /* + * We rely on the useRangeBasedImageFetching to fetch all image DTOs, caching them with RTK Query. + * + * In this component, we just want to consume that cache. Unforutnately, RTK Query does not provide a way to + * subscribe to a query without triggering a new fetch. + * + * There is a hack, though: + * - https://github.com/reduxjs/redux-toolkit/discussions/4213 + * + * This essentially means "subscribe to the query once it has some data". + */ - const { imageDTO } = useListImagesQuery(arg, options); + // Use `currentData` instead of `data` to prevent a flash of previous image rendered at this index + const { currentData: imageDTO, isUninitialized } = imagesApi.endpoints.getImageDTO.useQueryState(imageName); + imagesApi.endpoints.getImageDTO.useQuerySubscription(imageName, { skip: isUninitialized }); - return imageDTO; -}; - -// Individual image component that gets its data from RTK Query cache -const ImageAtPosition = memo( - ({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImageNamesQueryArgs }) => { - const imageDTO = useImageDTOFromListQuery(index, imageName, queryArgs); - - if (!imageDTO) { - return ; - } - - return ; + if (!imageDTO) { + return ; } -); + + return ; +}); ImageAtPosition.displayName = 'ImageAtPosition'; -// Memoized compute key function using image names -const computeItemKey: GridComputeItemKey = (_index, imageName, { queryArgs }) => { - return `${JSON.stringify(queryArgs)}-${imageName}`; +const computeItemKey: GridComputeItemKey = (index, imageName, { queryArgs }) => { + return `${JSON.stringify(queryArgs)}-${imageName ?? index}`; }; -// Physical DOM-based grid calculation using refs (based on working old implementation) +/** + * Calculate how many images fit in a row based on the current grid layout. + * + * TODO(psyche): We only need to do this when the gallery width changes, or when the galleryImageMinimumWidth value + * changes. Cache this calculation. + */ const getImagesPerRow = (rootEl: HTMLDivElement): number => { // Start from root and find virtuoso grid elements const gridElement = rootEl.querySelector('.virtuoso-grid-list'); @@ -124,7 +101,14 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => { return 0; } - // Use the exact calculation from the working old implementation + /** + * You might be tempted to just do some simple math like: + * const imagesPerRow = Math.floor(containerRect.width / itemRect.width); + * + * But floating point precision can cause issues with this approach, causing it to be off by 1 in some cases. + * + * Instead, we use a more robust approach that iteratively calculates how many images fit in the row. + */ let imagesPerRow = 0; let spaceUsed = 0; @@ -141,7 +125,9 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => { return Math.max(1, imagesPerRow); }; -// Check if an item at a given index is visible in the viewport +/** + * Scroll the item at the given index into view if it is not currently visible. + */ const scrollIntoView = ( index: number, rootEl: HTMLDivElement, @@ -202,6 +188,11 @@ const scrollIntoView = ( return; }; +/** + * Get the index of the image in the list of image names. + * If the image name is not found, return 0. + * If no image name is provided, return 0. + */ const getImageIndex = (imageName: string | undefined | null, imageNames: string[]) => { if (!imageName || imageNames.length === 0) { return 0; @@ -210,7 +201,9 @@ const getImageIndex = (imageName: string | undefined | null, imageNames: string[ return index >= 0 ? index : 0; }; -// Hook for keyboard navigation using physical DOM measurements +/** + * Handles keyboard navigation for the gallery. + */ const useKeyboardNavigation = ( imageNames: string[], virtuosoRef: React.RefObject, @@ -249,11 +242,12 @@ const useKeyboardNavigation = ( event.preventDefault(); + const state = getState(); const imageName = event.altKey ? // When the user holds alt, we are changing the image to compare - if no image to compare is currently selected, // we start from the last selected image - (selectImageToCompare(getState()) ?? selectLastSelectedImage(getState())) - : selectLastSelectedImage(getState()); + (selectImageToCompare(state) ?? selectLastSelectedImage(state)) + : selectLastSelectedImage(state); const currentIndex = getImageIndex(imageName, imageNames); @@ -373,6 +367,11 @@ const useKeyboardNavigation = ( }); }; +/** + * Keeps the last selected image in view when the gallery is scrolled. + * This is useful for keyboard navigation and ensuring the user can see their selection. + * It only tracks the last selected image, not the image to compare. + */ const useKeepSelectedImageInView = ( imageNames: string[], virtuosoRef: React.RefObject, @@ -397,6 +396,9 @@ const useKeepSelectedImageInView = ( }, [imageName, imageNames, rangeRef, rootRef, virtuosoRef]); }; +/** + * Handles the initialization of the overlay scrollbars for the gallery, returning the ref to the scroller element. + */ const useScrollableGallery = (rootRef: RefObject) => { const [scroller, scrollerRef] = useState(null); const [initialize, osInstance] = useOverlayScrollbars({ @@ -431,43 +433,49 @@ const useScrollableGallery = (rootRef: RefObject) => { return scrollerRef; }; -// Main gallery component export const NewGallery = memo(() => { const virtuosoRef = useRef(null); const rangeRef = useRef({ startIndex: 0, endIndex: 0 }); const rootRef = useRef(null); + const { isActiveTab } = useAutoLayoutContext(); // Get the ordered list of image names - this is our primary data source for virtualization const { queryArgs, imageNames, isLoading } = useGalleryImageNames(); + // Use range-based fetching for bulk loading image DTOs into cache based on the visible range + const { onRangeChanged } = useRangeBasedImageFetching({ + imageNames, + enabled: !isLoading && isActiveTab, + }); + useKeepSelectedImageInView(imageNames, virtuosoRef, rootRef, rangeRef); useKeyboardNavigation(imageNames, virtuosoRef, rootRef); const scrollerRef = useScrollableGallery(rootRef); - // We have to keep track of the visible range for keep-selected-image-in-view functionality - const handleRangeChanged = useCallback((range: ListRange) => { - rangeRef.current = range; - }, []); - - const context = useMemo( - () => - ({ - imageNames, - queryArgs, - }) satisfies GridContext, - [imageNames, queryArgs] + /* + * We have to keep track of the visible range for keep-selected-image-in-view functionality and push the range to + * the range-based image fetching hook. + */ + const handleRangeChanged = useCallback( + (range: ListRange) => { + rangeRef.current = range; + onRangeChanged(range); + }, + [onRangeChanged] ); + const context = useMemo(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]); + // Item content function - const itemContent: GridItemContent = useCallback((index, imageName, ctx) => { - return ; + const itemContent: GridItemContent = useCallback((index, imageName) => { + return ; }, []); if (isLoading) { return ( - - - Loading gallery... + + + Loading gallery... ); } @@ -481,12 +489,13 @@ export const NewGallery = memo(() => { } return ( + // This wrapper component is necessary to initialize the overlay scrollbars! ref={virtuosoRef} context={context} data={imageNames} - increaseViewportBy={VIEWPORT_BUFFER} + increaseViewportBy={2048} itemContent={itemContent} computeItemKey={computeItemKey} components={components} @@ -503,7 +512,7 @@ export const NewGallery = memo(() => { NewGallery.displayName = 'NewGallery'; const scrollSeekConfiguration: ScrollSeekConfiguration = { - enter: (velocity) => velocity > SCROLL_SEEK_VELOCITY_THRESHOLD, + enter: (velocity) => velocity > 4096, exit: (velocity) => velocity === 0, }; @@ -518,7 +527,7 @@ const selectGridTemplateColumns = createSelector( // Grid components const ListComponent: GridComponents['List'] = forwardRef(({ context: _, ...rest }, ref) => { const _gridTemplateColumns = useAppSelector(selectGridTemplateColumns); - const [gridTemplateColumns] = useDebounce(_gridTemplateColumns, DEBOUNCE_DELAY); + const [gridTemplateColumns] = useDebounce(_gridTemplateColumns, 300); return ; }); diff --git a/invokeai/frontend/web/src/features/gallery/components/use-gallery-image-names.ts b/invokeai/frontend/web/src/features/gallery/components/use-gallery-image-names.ts index dfbe3c3775..e1ffe31190 100644 --- a/invokeai/frontend/web/src/features/gallery/components/use-gallery-image-names.ts +++ b/invokeai/frontend/web/src/features/gallery/components/use-gallery-image-names.ts @@ -5,8 +5,8 @@ import { useGetImageNamesQuery } from 'services/api/endpoints/images'; import { useDebounce } from 'use-debounce'; const getImageNamesQueryOptions = { - selectFromResult: ({ data, isLoading, isFetching }) => ({ - imageNames: data ?? EMPTY_ARRAY, + selectFromResult: ({ currentData, isLoading, isFetching }) => ({ + imageNames: currentData?.image_names ?? EMPTY_ARRAY, isLoading, isFetching, }), @@ -14,7 +14,7 @@ const getImageNamesQueryOptions = { export const useGalleryImageNames = () => { const _queryArgs = useAppSelector(selectListImageNamesQueryArgs); - const [queryArgs] = useDebounce(_queryArgs, 500); + const [queryArgs] = useDebounce(_queryArgs, 300); const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(queryArgs, getImageNamesQueryOptions); return { imageNames, isLoading, isFetching, queryArgs }; }; diff --git a/invokeai/frontend/web/src/features/gallery/hooks/useRangeBasedImageFetching.ts b/invokeai/frontend/web/src/features/gallery/hooks/useRangeBasedImageFetching.ts new file mode 100644 index 0000000000..38284b7ce0 --- /dev/null +++ b/invokeai/frontend/web/src/features/gallery/hooks/useRangeBasedImageFetching.ts @@ -0,0 +1,86 @@ +import { useAppStore } from 'app/store/storeHooks'; +import { useCallback, useEffect, useState } from 'react'; +import type { ListRange } from 'react-virtuoso'; +import { imagesApi, useGetImageDTOsByNamesMutation } from 'services/api/endpoints/images'; +import { useThrottledCallback } from 'use-debounce'; + +interface UseRangeBasedImageFetchingArgs { + imageNames: string[]; + enabled: boolean; +} + +interface UseRangeBasedImageFetchingReturn { + onRangeChanged: (range: ListRange) => void; +} + +const getUncachedNames = (imageNames: string[], cachedImageNames: string[], range: ListRange): string[] => { + if (range.startIndex === range.endIndex) { + // If the start and end indices are the same, no range to fetch + return []; + } + + if (imageNames.length === 0) { + return []; + } + + const start = Math.max(0, range.startIndex); + const end = Math.min(imageNames.length - 1, range.endIndex); + + if (cachedImageNames.length === 0) { + return imageNames.slice(start, end + 1); + } + + const uncachedNames: string[] = []; + + for (let i = start; i <= end; i++) { + const imageName = imageNames[i]!; + if (!cachedImageNames.includes(imageName)) { + uncachedNames.push(imageName); + } + } + + return uncachedNames; +}; + +/** + * Hook for bulk fetching image DTOs based on the visible range from virtuoso. + * Individual image components should use `useGetImageDTOQuery(imageName)` to get their specific DTO. + * This hook ensures DTOs are bulk fetched and cached efficiently. + */ +export const useRangeBasedImageFetching = ({ + imageNames, + enabled, +}: UseRangeBasedImageFetchingArgs): UseRangeBasedImageFetchingReturn => { + const store = useAppStore(); + const [visibleRange, setVisibleRange] = useState({ startIndex: 0, endIndex: 0 }); + const [getImageDTOsByNames] = useGetImageDTOsByNamesMutation(); + + const fetchImages = useCallback( + (visibleRange: ListRange) => { + const cachedImageNames = imagesApi.util.selectCachedArgsForQuery(store.getState(), 'getImageDTO'); + const uncachedNames = getUncachedNames(imageNames, cachedImageNames, visibleRange); + if (uncachedNames.length === 0) { + return; + } + getImageDTOsByNames({ image_names: uncachedNames }); + }, + [getImageDTOsByNames, imageNames, store] + ); + + const throttledFetchImages = useThrottledCallback(fetchImages, 100); + + useEffect(() => { + if (!enabled) { + return; + } + throttledFetchImages(visibleRange); + }, [enabled, throttledFetchImages, imageNames, visibleRange]); + + const onRangeChanged = useCallback((range: ListRange) => { + setVisibleRange(range); + }, []); + + return { + onRangeChanged, + }; +}; diff --git a/invokeai/frontend/web/src/features/ui/layouts/auto-layout-context.tsx b/invokeai/frontend/web/src/features/ui/layouts/auto-layout-context.tsx index 77804b7356..6fb4f6c3ab 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/auto-layout-context.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/auto-layout-context.tsx @@ -1,5 +1,9 @@ +import { createSelector } from '@reduxjs/toolkit'; +import { useAppSelector } from 'app/store/storeHooks'; import type { DockviewApi, GridviewApi } from 'dockview'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; +import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { TabName } from 'features/ui/store/uiTypes'; import type { WritableAtom } from 'nanostores'; import { atom } from 'nanostores'; import type { PropsWithChildren, RefObject } from 'react'; @@ -8,6 +12,7 @@ import { createContext, memo, useCallback, useContext, useMemo, useState } from import { LEFT_PANEL_ID, LEFT_PANEL_MIN_SIZE_PX, RIGHT_PANEL_ID, RIGHT_PANEL_MIN_SIZE_PX } from './shared'; type AutoLayoutContextValue = { + isActiveTab: boolean; toggleLeftPanel: () => void; toggleRightPanel: () => void; toggleBothPanels: () => void; @@ -57,9 +62,15 @@ const activatePanel = (api: GridviewApi | DockviewApi, panelId: string) => { }; export const AutoLayoutProvider = ( - props: PropsWithChildren<{ $rootApi: WritableAtom; rootRef: RefObject }> + props: PropsWithChildren<{ + $rootApi: WritableAtom; + rootRef: RefObject; + tab: TabName; + }> ) => { - const { $rootApi, rootRef, children } = props; + const { $rootApi, rootRef, tab, children } = props; + const selectIsActiveTab = useMemo(() => createSelector(selectActiveTab, (activeTab) => activeTab === tab), [tab]); + const isActiveTab = useAppSelector(selectIsActiveTab); const $leftApi = useState(() => atom(null))[0]; const $centerApi = useState(() => atom(null))[0]; const $rightApi = useState(() => atom(null))[0]; @@ -126,6 +137,7 @@ export const AutoLayoutProvider = ( const value = useMemo( () => ({ + isActiveTab, toggleLeftPanel, toggleRightPanel, toggleBothPanels, @@ -138,6 +150,7 @@ export const AutoLayoutProvider = ( _$rightPanelApi: $rightApi, }), [ + isActiveTab, $centerApi, $leftApi, $rightApi, diff --git a/invokeai/frontend/web/src/features/ui/layouts/canvas-tab-auto-layout.tsx b/invokeai/frontend/web/src/features/ui/layouts/canvas-tab-auto-layout.tsx index 16a3a1df09..255edaa585 100644 --- a/invokeai/frontend/web/src/features/ui/layouts/canvas-tab-auto-layout.tsx +++ b/invokeai/frontend/web/src/features/ui/layouts/canvas-tab-auto-layout.tsx @@ -259,7 +259,7 @@ export const CanvasTabAutoLayout = memo(() => { useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef); return ( - + { useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef); return ( - + { useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef); return ( - + { useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef); return ( - + ({ + query: (body) => ({ + url: buildImagesUrl('images_by_names'), + method: 'POST', + body, + }), + // Don't provide cache tags - we'll manually upsert into individual getImageDTO caches + async onQueryStarted(_, { dispatch, queryFulfilled }) { + try { + const { data: imageDTOs } = await queryFulfilled; + + // Upsert each DTO into the individual image cache + const updates: Param0 = []; + for (const imageDTO of imageDTOs) { + updates.push({ + endpointName: 'getImageDTO', + arg: imageDTO.image_name, + value: imageDTO, + }); + } + dispatch(imagesApi.util.upsertQueryEntries(updates)); + } catch { + // Handle error if needed + } + }, + }), }), }); @@ -472,6 +505,7 @@ export const { useUnstarImagesMutation, useBulkDownloadImagesMutation, useGetImageNamesQuery, + useGetImageDTOsByNamesMutation, } = imagesApi; /** diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index ae843df433..44f4b5ecde 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -761,7 +761,7 @@ export type paths = { }; /** * Get Image Names - * @description Gets ordered list of all image names (starred first, then unstarred) + * @description Gets ordered list of image names with metadata for optimistic updates */ get: operations["get_image_names"]; put?: never; @@ -772,6 +772,26 @@ export type paths = { patch?: never; trace?: never; }; + "/api/v1/images/images_by_names": { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + get?: never; + put?: never; + /** + * Get Images By Names + * @description Gets image DTOs for the specified image names. Maintains order of input names. + */ + post: operations["get_images_by_names"]; + delete?: never; + options?: never; + head?: never; + patch?: never; + trace?: never; + }; "/api/v1/boards/": { parameters: { query?: never; @@ -2648,6 +2668,14 @@ export type components = { /** @description The validation run data to use for this batch. This is only used if this is a validation run. */ validation_run_data?: components["schemas"]["ValidationRunData"] | null; }; + /** Body_get_images_by_names */ + Body_get_images_by_names: { + /** + * Image Names + * @description Object containing list of image names to fetch DTOs for + */ + image_names: string[]; + }; /** Body_import_style_presets */ Body_import_style_presets: { /** @@ -10479,6 +10507,27 @@ export type components = { */ type: "img_nsfw"; }; + /** + * ImageNamesResult + * @description Response containing ordered image names with metadata for optimistic updates. + */ + ImageNamesResult: { + /** + * Image Names + * @description Ordered list of image names + */ + image_names: string[]; + /** + * Starred Count + * @description Number of starred images (when starred_first=True) + */ + starred_count: number; + /** + * Total Count + * @description Total number of images matching the query + */ + total_count: number; + }; /** * Add Image Noise * @description Add noise to an image @@ -23725,7 +23774,40 @@ export interface operations { [name: string]: unknown; }; content: { - "application/json": string[]; + "application/json": components["schemas"]["ImageNamesResult"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; + get_images_by_names: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["Body_get_images_by_names"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["ImageDTO"][]; }; }; /** @description Validation Error */ diff --git a/invokeai/frontend/web/src/services/api/types.ts b/invokeai/frontend/web/src/services/api/types.ts index 5b7c234f1d..3a520024f9 100644 --- a/invokeai/frontend/web/src/services/api/types.ts +++ b/invokeai/frontend/web/src/services/api/types.ts @@ -7,6 +7,8 @@ export type S = components['schemas']; export type ListImagesArgs = NonNullable; export type ListImagesResponse = paths['/api/v1/images/']['get']['responses']['200']['content']['application/json']; +export type ImageNamesResult = S['ImageNamesResult']; + export type ListBoardsArgs = NonNullable; export type DeleteBoardResult = diff --git a/invokeai/frontend/web/src/services/api/util/optimisticUpdates.ts b/invokeai/frontend/web/src/services/api/util/optimisticUpdates.ts new file mode 100644 index 0000000000..e16ff37b52 --- /dev/null +++ b/invokeai/frontend/web/src/services/api/util/optimisticUpdates.ts @@ -0,0 +1,90 @@ +import type { ImageDTO, ImageNamesResult } from 'services/api/types'; + +/** + * Calculates the optimal insertion position for a new image in the names list. + * For starred_first=true: starred images go to position 0, unstarred go after all starred images + * For starred_first=false: all new images go to position 0 (newest first) + */ +export function calculateImageInsertionPosition( + imageDTO: ImageDTO, + starredFirst: boolean, + starredCount: number +): number { + if (!starredFirst) { + // When starred_first is false, always insert at the beginning (newest first) + return 0; + } + + // When starred_first is true + if (imageDTO.starred) { + // Starred images go at the very beginning + return 0; + } + + // Unstarred images go after all starred images + return starredCount; +} + +/** + * Optimistically inserts a new image into the ImageNamesResult at the correct position + */ +export function insertImageIntoNamesResult( + currentResult: ImageNamesResult, + imageDTO: ImageDTO, + starredFirst: boolean +): ImageNamesResult { + // Don't insert if the image is already in the list + if (currentResult.image_names.includes(imageDTO.image_name)) { + return currentResult; + } + + const insertPosition = calculateImageInsertionPosition(imageDTO, starredFirst, currentResult.starred_count); + + const newImageNames = [...currentResult.image_names]; + newImageNames.splice(insertPosition, 0, imageDTO.image_name); + + return { + image_names: newImageNames, + starred_count: starredFirst && imageDTO.starred ? currentResult.starred_count + 1 : currentResult.starred_count, + total_count: currentResult.total_count + 1, + }; +} + +/** + * Optimistically removes an image from the ImageNamesResult + */ +export function removeImageFromNamesResult( + currentResult: ImageNamesResult, + imageNameToRemove: string, + wasStarred: boolean, + starredFirst: boolean +): ImageNamesResult { + const newImageNames = currentResult.image_names.filter((name) => name !== imageNameToRemove); + + return { + image_names: newImageNames, + starred_count: starredFirst && wasStarred ? currentResult.starred_count - 1 : currentResult.starred_count, + total_count: currentResult.total_count - 1, + }; +} + +/** + * Optimistically updates an image's position in the result when its starred status changes + */ +export function updateImagePositionInNamesResult( + currentResult: ImageNamesResult, + updatedImageDTO: ImageDTO, + previouslyStarred: boolean, + starredFirst: boolean +): ImageNamesResult { + // First remove the image from its current position + const withoutImage = removeImageFromNamesResult( + currentResult, + updatedImageDTO.image_name, + previouslyStarred, + starredFirst + ); + + // Then insert it at the new correct position + return insertImageIntoNamesResult(withoutImage, updatedImageDTO, starredFirst); +} diff --git a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx index 07ecd537d7..19ae743991 100644 --- a/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx +++ b/invokeai/frontend/web/src/services/events/onInvocationComplete.tsx @@ -1,24 +1,22 @@ import { logger } from 'app/logging/logger'; -import { addAppListener } from 'app/store/middleware/listenerMiddleware'; import type { AppDispatch, AppGetState } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; import { selectAutoSwitch, selectGalleryView, - selectListImagesBaseQueryArgs, + selectListImageNamesQueryArgs, selectSelectedBoardId, } from 'features/gallery/store/gallerySelectors'; import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice'; import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState'; import { isImageField, isImageFieldCollection } from 'features/nodes/types/common'; import { zNodeStatus } from 'features/nodes/types/invocation'; -import type { ApiTagDescription } from 'services/api'; import { boardsApi } from 'services/api/endpoints/boards'; import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images'; import type { ImageDTO, S } from 'services/api/types'; import { getCategories } from 'services/api/util'; +import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates'; import { $lastProgressEvent } from 'services/events/stores'; -import stableHash from 'stable-hash'; import type { Param0 } from 'tsafe'; import { objectEntries } from 'tsafe'; import type { JsonObject } from 'type-fest'; @@ -42,9 +40,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi // For efficiency's sake, we want to minimize the number of dispatches and invalidations we do. // We'll keep track of each change we need to make and do them all at once. const boardTotalAdditions: Record = {}; - const boardTagIdsToInvalidate: Set = new Set(); - const imageListTagIdsToInvalidate: Set = new Set(); - const listImagesArg = selectListImagesBaseQueryArgs(getState()); + const listImageNamesArg = selectListImageNamesQueryArgs(getState()); for (const imageDTO of imageDTOs) { if (imageDTO.is_intermediate) { @@ -54,17 +50,6 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi const board_id = imageDTO.board_id ?? 'none'; // update the total images for the board boardTotalAdditions[board_id] = (boardTotalAdditions[board_id] || 0) + 1; - // invalidate the board tag - boardTagIdsToInvalidate.add(board_id); - // invalidate the image list tag - imageListTagIdsToInvalidate.add( - stableHash({ - ...listImagesArg, - categories: getCategories(imageDTO), - board_id, - offset: 0, - }) - ); } // Update all the board image totals at once @@ -85,16 +70,40 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi } dispatch(boardsApi.util.upsertQueryEntries(entries)); - // Invalidate all tags at once - const boardTags: ApiTagDescription[] = Array.from(boardTagIdsToInvalidate).map((boardId) => ({ - type: 'Board' as const, - id: boardId, - })); - const imageListTags: ApiTagDescription[] = Array.from(imageListTagIdsToInvalidate).map((imageListId) => ({ - type: 'ImageList' as const, - id: imageListId, - })); - dispatch(imagesApi.util.invalidateTags(['ImageNameList', ...boardTags, ...imageListTags])); + // Optimistically update image names lists - DTOs are already cached by getResultImageDTOs + const state = getState(); + + for (const imageDTO of imageDTOs) { + // Construct the expected query args for this image's getImageNames query + // Use the current gallery query args as base, but override board_id and categories for this specific image + const expectedQueryArgs = { + ...listImageNamesArg, + categories: getCategories(imageDTO), + board_id: imageDTO.board_id ?? 'none', + }; + + // Check if we have cached image names for this query + const cachedNamesResult = imagesApi.endpoints.getImageNames.select(expectedQueryArgs)(state); + + if (cachedNamesResult.data) { + // We have cached names - optimistically insert the new image + dispatch( + imagesApi.util.updateQueryData('getImageNames', expectedQueryArgs, (draft) => { + // Use the utility function to insert at the correct position + const updatedResult = insertImageIntoNamesResult(draft, imageDTO, expectedQueryArgs.starred_first ?? true); + + // Replace the draft contents + draft.image_names = updatedResult.image_names; + draft.starred_count = updatedResult.starred_count; + draft.total_count = updatedResult.total_count; + }) + ); + } + // If no cached data, we don't need to do anything - there's no list to update + } + + // No need to invalidate tags since we're doing optimistic updates + // Board totals are already updated above via upsertQueryEntries const autoSwitch = selectAutoSwitch(getState()); @@ -112,63 +121,27 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi const { image_name } = lastImageDTO; const board_id = lastImageDTO.board_id ?? 'none'; - /** - * Auto-switch needs a bit of care to avoid race conditions - we need to invalidate the appropriate image list - * query cache, and only after it has loaded, select the new image. - */ - const queryArgs = { - ...listImagesArg, - categories: getCategories(lastImageDTO), - board_id, - offset: 0, - }; + // With optimistic updates, we can immediately switch to the new image + const selectedBoardId = selectSelectedBoardId(getState()); - dispatch( - addAppListener({ - predicate: (action) => { - if (!imagesApi.endpoints.listImages.matchFulfilled(action)) { - return false; - } - - if (stableHash(action.meta.arg.originalArgs) !== stableHash(queryArgs)) { - return false; - } - - return true; - }, - effect: (_action, { getState, dispatch, unsubscribe }) => { - // This is a one-time listener - we always unsubscribe after the first match - unsubscribe(); - - // Auto-switch may have been disabled while we were waiting for the query to resolve - bail if so - const autoSwitch = selectAutoSwitch(getState()); - if (!autoSwitch) { - return; - } - - const selectedBoardId = selectSelectedBoardId(getState()); - - // If the image is from a different board, switch to that board & select the image - otherwise just select the - // image. This implicitly changes the view to 'images' if it was not already. - if (board_id !== selectedBoardId) { - dispatch( - boardIdSelected({ - boardId: board_id, - selectedImageName: image_name, - }) - ); - } else { - // Ensure we are on the 'images' gallery view - that's where this image will be displayed - const galleryView = selectGalleryView(getState()); - if (galleryView !== 'images') { - dispatch(galleryViewChanged('images')); - } - // Else just select the image, no need to switch boards - dispatch(imageSelected(lastImageDTO.image_name)); - } - }, - }) - ); + // If the image is from a different board, switch to that board & select the image - otherwise just select the + // image. This implicitly changes the view to 'images' if it was not already. + if (board_id !== selectedBoardId) { + dispatch( + boardIdSelected({ + boardId: board_id, + selectedImageName: image_name, + }) + ); + } else { + // Ensure we are on the 'images' gallery view - that's where this image will be displayed + const galleryView = selectGalleryView(getState()); + if (galleryView !== 'images') { + dispatch(galleryViewChanged('images')); + } + // Select the image immediately since we've optimistically updated the cache + dispatch(imageSelected(lastImageDTO.image_name)); + } }; const getResultImageDTOs = async (data: S['InvocationCompleteEvent']): Promise => {