From 8327d8677483a95c2c0fa5b4e236dc721e66e158 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 24 Jun 2025 21:31:37 +1000 Subject: [PATCH] refactor: gallery scroll (improved impl) --- .../gallery/components/NewGallery.tsx | 235 ++++++++++-------- 1 file changed, 137 insertions(+), 98 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx index b4838ba142..ebc5bc1615 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -1,15 +1,17 @@ import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library'; +import { logger } from 'app/logging/logger'; import { useAppSelector } from 'app/store/storeHooks'; import { selectGalleryImageMinimumWidth, selectImageCollectionQueryArgs, } from 'features/gallery/store/gallerySelectors'; import { useOverlayScrollbars } from 'overlayscrollbars-react'; -import { memo, useEffect, useMemo, useRef, useState } from 'react'; +import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import type { GridComponents, GridComputeItemKey, GridItemContent, + ListRange, ScrollSeekConfiguration, VirtuosoGridHandle, } from 'react-virtuoso'; @@ -18,12 +20,16 @@ import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery, useGetImageNamesQuery, + useLazyGetImageCollectionQuery, } from 'services/api/endpoints/images'; -import type { ImageCategory, SQLiteDirection } from 'services/api/types'; +import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types'; +import { objectEntries } from 'tsafe'; import { useDebounce } from 'use-debounce'; import { GalleryImage } from './ImageGrid/GalleryImage'; +const log = logger('gallery'); + // Type for image collection query arguments type ImageCollectionQueryArgs = { board_id?: string; @@ -33,18 +39,21 @@ type ImageCollectionQueryArgs = { is_intermediate: boolean; }; -// Types -type Collection = 'starred' | 'unstarred'; - -interface PositionInfo { - collection: Collection; - offset: number; - itemIndex: number; -} - // Constants const RANGE_SIZE = 50; +type GridContext = { + queryArgs: ImageCollectionQueryArgs; + imageNames: string[]; + starredCount: number; +}; + +type PositionInfo = { + collection: 'starred' | 'unstarred'; + offset: number; + itemIndex: number; +}; + // Helper to calculate which collection and range an index belongs to const getPositionInfo = (index: number, starredCount: number): PositionInfo => { if (index < starredCount) { @@ -67,68 +76,63 @@ const getPositionInfo = (index: number, starredCount: number): PositionInfo => { } }; -// Hook to get image at a specific position -const useImageAtPosition = (index: number, starredCount: number, queryArgs: ImageCollectionQueryArgs) => { - const positionInfo = useMemo(() => getPositionInfo(index, starredCount), [index, starredCount]); +// Hook to get image DTO from batched collection data +const useImageFromBatch = ( + imageName: string, + index: number, + starredCount: number, + queryArgs: ImageCollectionQueryArgs +): ImageDTO | null => { + const { arg, options } = useMemo(() => { + const positionInfo = getPositionInfo(index, starredCount); - const arg = useMemo( - () => - ({ - collection: positionInfo.collection, - offset: positionInfo.offset, - limit: RANGE_SIZE, - ...queryArgs, - }) satisfies Parameters[0], - [positionInfo.collection, positionInfo.offset, queryArgs] - ); + const arg = { + collection: positionInfo.collection, + offset: positionInfo.offset, + limit: RANGE_SIZE, + ...queryArgs, + } satisfies Parameters[0]; - const options = useMemo( - () => - ({ - selectFromResult: ({ data }) => { - if (!data) { - return { imageDTO: null }; - } else { - return { - imageDTO: data.items[positionInfo.itemIndex] || null, - }; - } - }, - }) satisfies Parameters[1], - [positionInfo.itemIndex] - ); + const options = { + selectFromResult: ({ data }) => { + const imageDTO = data?.items?.[positionInfo.itemIndex] || null; + if (imageDTO && imageDTO.image_name !== imageName) { + log.warnOnce(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`); + } + return { imageDTO }; + }, + } satisfies Parameters[1]; + + return { arg, options }; + }, [imageName, index, queryArgs, starredCount]); const { imageDTO } = useGetImageCollectionQuery(arg, options); return imageDTO; }; -type ImageAtPositionProps = { - index: number; - starredCount: number; - queryArgs: ImageCollectionQueryArgs; -}; +// Individual image component that gets its data from batched requests +const ImageAtPosition = memo( + ({ + imageName, + index, + starredCount, + queryArgs, + }: { + imageName: string; + index: number; + starredCount: number; + queryArgs: ImageCollectionQueryArgs; + }) => { + const imageDTO = useImageFromBatch(imageName, index, starredCount, queryArgs); -type GridContext = { - queryArgs: ImageCollectionQueryArgs; - counts: { - starred_count: number; - unstarred_count: number; - total_count: number; - }; -}; + if (!imageDTO) { + return ; + } -// Individual image component -const ImageAtPosition = memo(({ index, starredCount, queryArgs }: ImageAtPositionProps) => { - const imageDTO = useImageAtPosition(index, starredCount, queryArgs); - - if (!imageDTO) { - return ; + return ; } - - return ; -}); - +); ImageAtPosition.displayName = 'ImageAtPosition'; export const useDebouncedImageCollectionQueryArgs = () => { @@ -137,31 +141,52 @@ export const useDebouncedImageCollectionQueryArgs = () => { return queryArgs; }; -const getImageCollectionCountsOptions = { - selectFromResult: ({ data, isLoading }) => ({ - counts: data - ? { - starred_count: data.starred_count, - unstarred_count: data.unstarred_count, - total_count: data.starred_count + data.unstarred_count, - } - : { - starred_count: 0, - unstarred_count: 0, - total_count: 0, - }, - isLoading, - }), -} satisfies Parameters[1]; - -// Memoized item content function -const itemContent: GridItemContent = (index, _item, { queryArgs, counts }) => { - return ; +// Memoized item content function that uses image names as data but batches requests +const itemContent: GridItemContent = (index, imageName, { queryArgs, starredCount }) => { + if (!imageName) { + return ; + } + return ; }; -// Memoized compute key function -const computeItemKey: GridComputeItemKey = (index, _item, { queryArgs }) => { - return `${JSON.stringify(queryArgs)}-${index}`; +// Memoized compute key function using image names +const computeItemKey: GridComputeItemKey = (index, imageName, { queryArgs }) => { + return `${JSON.stringify(queryArgs)}-${imageName || index}`; +}; + +// Hook to prefetch ranges based on visible area +const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQueryArgs) => { + const [triggerGetImageCollection] = useLazyGetImageCollectionQuery(); + + const prefetchRange = useCallback( + (startIndex: number, endIndex: number) => { + const ranges = { + starred: new Set(), + unstarred: new Set(), + }; + + // Collect all unique ranges needed for the visible area + for (let i = startIndex; i <= endIndex; i++) { + const positionInfo = getPositionInfo(i, starredCount); + ranges[positionInfo.collection].add(positionInfo.offset); + } + + // Trigger queries for each unique range + for (const [collection, offsets] of objectEntries(ranges)) { + for (const offset of offsets) { + triggerGetImageCollection({ + collection, + offset, + limit: RANGE_SIZE, + ...queryArgs, + }); + } + } + }, + [starredCount, queryArgs, triggerGetImageCollection] + ); + + return prefetchRange; }; // Main gallery component @@ -169,18 +194,21 @@ export const NewGallery = memo(() => { const queryArgs = useDebouncedImageCollectionQueryArgs(); const virtuosoRef = useRef(null); - const { counts, isLoading } = useGetImageCollectionCountsQuery(queryArgs, getImageCollectionCountsOptions); + // Get the ordered list of image names - this is our primary data source + const { data: imageNames = [], isLoading } = useGetImageNamesQuery(queryArgs); - // Load image names for selection operations - this is lightweight and ensures - // selection operations work even before image data is fully loaded - useGetImageNamesQuery(queryArgs); + // Get starred count for position calculations + const { data: counts } = useGetImageCollectionCountsQuery(queryArgs); + const starredCount = counts?.starred_count ?? 0; + + const prefetchRange = usePrefetchRanges(starredCount, queryArgs); // Reset scroll position when query parameters change useEffect(() => { - if (virtuosoRef.current && counts.total_count > 0) { + if (virtuosoRef.current && imageNames.length > 0) { virtuosoRef.current.scrollToIndex({ index: 0, behavior: 'auto' }); } - }, [counts.total_count, queryArgs]); + }, [queryArgs, imageNames.length]); const rootRef = useRef(null); const [scroller, setScroller] = useState(null); @@ -213,13 +241,22 @@ export const NewGallery = memo(() => { }; }, [scroller, initialize, osInstance]); + // Handle range changes to prefetch data for visible + buffer areas + const handleRangeChanged = useCallback( + (range: ListRange) => { + prefetchRange(range.startIndex, range.endIndex); + }, + [prefetchRange] + ); + const context = useMemo( () => ({ - counts, + imageNames, queryArgs, + starredCount, }) satisfies GridContext, - [counts, queryArgs] + [imageNames, queryArgs, starredCount] ); if (isLoading) { @@ -231,7 +268,7 @@ export const NewGallery = memo(() => { ); } - if (counts.total_count === 0) { + if (imageNames.length === 0) { return ( No images found @@ -241,17 +278,19 @@ export const NewGallery = memo(() => { return ( - + ref={virtuosoRef} context={context} - totalCount={counts.total_count} - increaseViewportBy={1024} + totalCount={imageNames.length} + data={imageNames} + increaseViewportBy={2048} itemContent={itemContent} computeItemKey={computeItemKey} components={components} style={style} scrollerRef={setScroller} scrollSeekConfiguration={scrollSeekConfiguration} + rangeChanged={handleRangeChanged} /> ); @@ -260,7 +299,7 @@ export const NewGallery = memo(() => { NewGallery.displayName = 'NewGallery'; const scrollSeekConfiguration: ScrollSeekConfiguration = { - enter: (velocity) => velocity > 1000, + enter: (velocity) => velocity > 2048, exit: (velocity) => velocity === 0, };