diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx index 9f6807d278..985a8a6427 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -1,7 +1,8 @@ import { Box, Flex, forwardRef, Grid, GridItem, Image, Skeleton, Spinner, Text } from '@invoke-ai/ui-library'; +import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors'; -import React, { memo, useCallback, useMemo, useState } from 'react'; +import { memo, useCallback, useMemo } from 'react'; import { VirtuosoGrid } from 'react-virtuoso'; import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; @@ -18,264 +19,113 @@ const ImageSkeleton = memo(() => ); ImageSkeleton.displayName = 'ImageSkeleton'; -// Hook to manage image data for virtual scrolling +// Hook to manage position calculations and image access const useVirtualImageData = () => { const queryArgs = useAppSelector(selectImageCollectionQueryArgs); // Get total counts for position mapping const { data: counts, isLoading: countsLoading } = useGetImageCollectionCountsQuery(queryArgs); - // Cache for loaded image ranges - const [loadedRanges, setLoadedRanges] = useState>(new Map()); - // Calculate position mappings const positionInfo = useMemo(() => { if (!counts) { return null; } - const result = { + return { totalCount: counts.total_count, starredCount: counts.starred_count ?? 0, unstarredCount: counts.unstarred_count ?? 0, starredEnd: (counts.starred_count ?? 0) - 1, }; - - return result; }, [counts]); - // Clear cache when search parameters change - React.useEffect(() => { - setLoadedRanges(new Map()); - }, [queryArgs.board_id, queryArgs.search_term, queryArgs.categories]); - - // Return flag to indicate when search parameters have changed - const searchParamsChanged = useMemo(() => queryArgs, [queryArgs]); - - // Function to generate cache key for a range - const getRangeKey = useCallback((collection: 'starred' | 'unstarred', offset: number, limit: number) => { - return `${collection}-${offset}-${limit}`; - }, []); - - // Function to get images for a specific position range - const getImagesForRange = useCallback( - (startIndex: number, endIndex: number) => { + // Function to get query params for a specific position + const getQueryParamsForPosition = useCallback( + (index: number) => { if (!positionInfo) { - return []; + return null; } - const requestedImages: (ImageDTO | null)[] = new Array(endIndex - startIndex + 1).fill(null); - const rangesToLoad: Array<{ - collection: 'starred' | 'unstarred'; - offset: number; - limit: number; - targetStartIndex: number; - }> = []; - - for (let i = startIndex; i <= endIndex; i++) { - const relativeIndex = i - startIndex; - - // Handle case where there are no starred images - if (positionInfo.starredCount === 0 || i >= positionInfo.starredCount) { - // This position is in the unstarred collection - const unstarredOffset = i - positionInfo.starredCount; - const rangeKey = getRangeKey('unstarred', Math.floor(unstarredOffset / 50) * 50, 50); - const cachedRange = loadedRanges.get(rangeKey); - - if (cachedRange) { - const imageIndex = unstarredOffset % 50; - if (imageIndex < cachedRange.length) { - requestedImages[relativeIndex] = cachedRange[imageIndex] ?? null; - } - } else { - // Need to load this range - const rangeOffset = Math.floor(unstarredOffset / 50) * 50; - rangesToLoad.push({ - collection: 'unstarred', - offset: rangeOffset, - limit: 50, - targetStartIndex: i, - }); - } - } else { - // This position is in the starred collection - const starredOffset = i; - const rangeKey = getRangeKey('starred', Math.floor(starredOffset / 50) * 50, 50); - const cachedRange = loadedRanges.get(rangeKey); - - if (cachedRange) { - const imageIndex = starredOffset % 50; - if (imageIndex < cachedRange.length) { - requestedImages[relativeIndex] = cachedRange[imageIndex] ?? null; - } - } else { - // Need to load this range - const rangeOffset = Math.floor(starredOffset / 50) * 50; - rangesToLoad.push({ - collection: 'starred', - offset: rangeOffset, - limit: 50, - targetStartIndex: i, - }); - } - } + if (positionInfo.starredCount === 0 || index >= positionInfo.starredCount) { + // This position is in the unstarred collection + const unstarredOffset = index - positionInfo.starredCount; + const rangeOffset = Math.floor(unstarredOffset / 50) * 50; + return { + collection: 'unstarred' as const, + offset: rangeOffset, + limit: 50, + imageIndex: unstarredOffset % 50, + }; + } else { + // This position is in the starred collection + const rangeOffset = Math.floor(index / 50) * 50; + return { + collection: 'starred' as const, + offset: rangeOffset, + limit: 50, + imageIndex: index % 50, + }; } - - return { images: requestedImages, rangesToLoad }; }, - [positionInfo, loadedRanges, getRangeKey] + [positionInfo] ); return { positionInfo, countsLoading, - getImagesForRange, - setLoadedRanges, - loadedRanges, - searchParamsChanged, + getQueryParamsForPosition, + queryArgs, }; }; -// Component to handle loading image ranges -const ImageRangeLoader = memo( - ({ - collection, - offset, - limit, - onDataLoaded, - }: { - collection: 'starred' | 'unstarred'; - offset: number; - limit: number; - onDataLoaded: (key: string, images: ImageDTO[]) => void; - }) => { - const queryArgs = useAppSelector(selectImageCollectionQueryArgs); +// Hook to get image data for a specific position using RTK Query cache +const useImageAtPosition = (index: number) => { + const { getQueryParamsForPosition, queryArgs } = useVirtualImageData(); - const { data } = useGetImageCollectionQuery({ - collection, - offset, - limit, - ...queryArgs, - }); + const queryParams = getQueryParamsForPosition(index); - // Update cache when data is loaded - use useEffect to avoid state update during render - React.useEffect(() => { - if (data?.items) { - const key = `${collection}-${offset}-${limit}`; - onDataLoaded(key, data.items); - } - }, [data, collection, offset, limit, onDataLoaded]); + const { data } = useGetImageCollectionQuery( + queryParams + ? { + collection: queryParams.collection, + offset: queryParams.offset, + limit: queryParams.limit, + ...queryArgs, + } + : skipToken + ); + if (!queryParams || !data?.items) { return null; } -); -ImageRangeLoader.displayName = 'ImageRangeLoader'; + return data.items[queryParams.imageIndex] || null; +}; + +// Component to render a single image at a position +const ImageAtPosition = memo(({ index }: { index: number }) => { + const image = useImageAtPosition(index); + + if (image) { + return ; + } + + return ; +}); + +ImageAtPosition.displayName = 'ImageAtPosition'; export const NewGallery = memo(() => { - const { positionInfo, countsLoading, getImagesForRange, setLoadedRanges, searchParamsChanged } = - useVirtualImageData(); - const [activeRangeLoaders, setActiveRangeLoaders] = useState>(new Set()); - - // Force initial range loading when position info becomes available - const [hasInitiallyLoaded, setHasInitiallyLoaded] = useState(false); - - // Reset hasInitiallyLoaded when search parameters change - React.useEffect(() => { - setHasInitiallyLoaded(false); - setActiveRangeLoaders(new Set()); - }, [searchParamsChanged]); - - // Use useEffect for initial load to avoid state updates during render - React.useEffect(() => { - if (positionInfo && !hasInitiallyLoaded) { - // Force initial load of first 100 positions to ensure we see both starred and unstarred - const initialResult = getImagesForRange(0, Math.min(99, positionInfo.totalCount - 1)); - if (!Array.isArray(initialResult)) { - const { rangesToLoad } = initialResult; - rangesToLoad.forEach((rangeInfo) => { - const key = `${rangeInfo.collection}-${rangeInfo.offset}-${rangeInfo.limit}`; - if (!activeRangeLoaders.has(key)) { - setActiveRangeLoaders((prev) => new Set(prev).add(key)); - } - }); - } - setHasInitiallyLoaded(true); - } - }, [positionInfo, hasInitiallyLoaded, getImagesForRange, activeRangeLoaders]); - - // Handle range changes from virtuoso - const handleRangeChanged = useCallback( - (range: { startIndex: number; endIndex: number }) => { - if (!positionInfo) { - return; - } - - const result = getImagesForRange(range.startIndex, range.endIndex); - if (!Array.isArray(result)) { - const { rangesToLoad } = result; - - // Start loading any missing ranges - rangesToLoad.forEach((rangeInfo) => { - const key = `${rangeInfo.collection}-${rangeInfo.offset}-${rangeInfo.limit}`; - if (!activeRangeLoaders.has(key)) { - setActiveRangeLoaders((prev) => new Set(prev).add(key)); - } - }); - } - }, - [positionInfo, getImagesForRange, activeRangeLoaders] - ); - - // Handle when range data is loaded - const handleDataLoaded = useCallback( - (key: string, images: ImageDTO[]) => { - setLoadedRanges((prev) => new Map(prev).set(key, images)); - setActiveRangeLoaders((prev) => { - const next = new Set(prev); - next.delete(key); - return next; - }); - }, - [setLoadedRanges] - ); - - const computeItemKey = useCallback( - (index: number) => { - const result = getImagesForRange(index, index); - if (Array.isArray(result)) { - return `loading-${index}`; - } - const { images } = result; - const image = images[0]; - return image ? `image-${index}-${image.image_name}` : `skeleton-${index}`; - }, - [getImagesForRange] - ); + const { positionInfo, countsLoading } = useVirtualImageData(); // Render item at specific index - const itemContent = useCallback( - (index: number) => { - if (!positionInfo) { - return ; - } + const itemContent = useCallback((index: number) => { + return ; + }, []); - const result = getImagesForRange(index, index); - if (Array.isArray(result)) { - return ; - } - - const { images } = result; - const image = images[0]; - - if (image) { - return ; - } - - return ; - }, - [positionInfo, getImagesForRange] - ); + // Compute item key using position index - let RTK Query handle the caching + const computeItemKey = useCallback((index: number) => `position-${index}`, []); if (countsLoading) { return ( @@ -296,25 +146,10 @@ export const NewGallery = memo(() => { return ( - {/* Render active range loaders */} - {Array.from(activeRangeLoaders).map((key) => { - const [collection, offset, limit] = key.split('-'); - return ( - - ); - })} - {/* Virtualized grid */}