mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 23:35:12 -05:00
refactor: gallery scroll (improved impl)
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import { Box, Flex, forwardRef, Grid, GridItem, Image, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import React, { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
@@ -18,264 +19,113 @@ const ImageSkeleton = memo(() => <Skeleton w="full" h="full" />);
|
||||
|
||||
ImageSkeleton.displayName = 'ImageSkeleton';
|
||||
|
||||
// Hook to manage image data for virtual scrolling
|
||||
// Hook to manage position calculations and image access
|
||||
const useVirtualImageData = () => {
|
||||
const queryArgs = useAppSelector(selectImageCollectionQueryArgs);
|
||||
|
||||
// Get total counts for position mapping
|
||||
const { data: counts, isLoading: countsLoading } = useGetImageCollectionCountsQuery(queryArgs);
|
||||
|
||||
// Cache for loaded image ranges
|
||||
const [loadedRanges, setLoadedRanges] = useState<Map<string, ImageDTO[]>>(new Map());
|
||||
|
||||
// Calculate position mappings
|
||||
const positionInfo = useMemo(() => {
|
||||
if (!counts) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const result = {
|
||||
return {
|
||||
totalCount: counts.total_count,
|
||||
starredCount: counts.starred_count ?? 0,
|
||||
unstarredCount: counts.unstarred_count ?? 0,
|
||||
starredEnd: (counts.starred_count ?? 0) - 1,
|
||||
};
|
||||
|
||||
return result;
|
||||
}, [counts]);
|
||||
|
||||
// Clear cache when search parameters change
|
||||
React.useEffect(() => {
|
||||
setLoadedRanges(new Map());
|
||||
}, [queryArgs.board_id, queryArgs.search_term, queryArgs.categories]);
|
||||
|
||||
// Return flag to indicate when search parameters have changed
|
||||
const searchParamsChanged = useMemo(() => queryArgs, [queryArgs]);
|
||||
|
||||
// Function to generate cache key for a range
|
||||
const getRangeKey = useCallback((collection: 'starred' | 'unstarred', offset: number, limit: number) => {
|
||||
return `${collection}-${offset}-${limit}`;
|
||||
}, []);
|
||||
|
||||
// Function to get images for a specific position range
|
||||
const getImagesForRange = useCallback(
|
||||
(startIndex: number, endIndex: number) => {
|
||||
// Function to get query params for a specific position
|
||||
const getQueryParamsForPosition = useCallback(
|
||||
(index: number) => {
|
||||
if (!positionInfo) {
|
||||
return [];
|
||||
return null;
|
||||
}
|
||||
|
||||
const requestedImages: (ImageDTO | null)[] = new Array(endIndex - startIndex + 1).fill(null);
|
||||
const rangesToLoad: Array<{
|
||||
collection: 'starred' | 'unstarred';
|
||||
offset: number;
|
||||
limit: number;
|
||||
targetStartIndex: number;
|
||||
}> = [];
|
||||
|
||||
for (let i = startIndex; i <= endIndex; i++) {
|
||||
const relativeIndex = i - startIndex;
|
||||
|
||||
// Handle case where there are no starred images
|
||||
if (positionInfo.starredCount === 0 || i >= positionInfo.starredCount) {
|
||||
// This position is in the unstarred collection
|
||||
const unstarredOffset = i - positionInfo.starredCount;
|
||||
const rangeKey = getRangeKey('unstarred', Math.floor(unstarredOffset / 50) * 50, 50);
|
||||
const cachedRange = loadedRanges.get(rangeKey);
|
||||
|
||||
if (cachedRange) {
|
||||
const imageIndex = unstarredOffset % 50;
|
||||
if (imageIndex < cachedRange.length) {
|
||||
requestedImages[relativeIndex] = cachedRange[imageIndex] ?? null;
|
||||
}
|
||||
} else {
|
||||
// Need to load this range
|
||||
const rangeOffset = Math.floor(unstarredOffset / 50) * 50;
|
||||
rangesToLoad.push({
|
||||
collection: 'unstarred',
|
||||
offset: rangeOffset,
|
||||
limit: 50,
|
||||
targetStartIndex: i,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// This position is in the starred collection
|
||||
const starredOffset = i;
|
||||
const rangeKey = getRangeKey('starred', Math.floor(starredOffset / 50) * 50, 50);
|
||||
const cachedRange = loadedRanges.get(rangeKey);
|
||||
|
||||
if (cachedRange) {
|
||||
const imageIndex = starredOffset % 50;
|
||||
if (imageIndex < cachedRange.length) {
|
||||
requestedImages[relativeIndex] = cachedRange[imageIndex] ?? null;
|
||||
}
|
||||
} else {
|
||||
// Need to load this range
|
||||
const rangeOffset = Math.floor(starredOffset / 50) * 50;
|
||||
rangesToLoad.push({
|
||||
collection: 'starred',
|
||||
offset: rangeOffset,
|
||||
limit: 50,
|
||||
targetStartIndex: i,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (positionInfo.starredCount === 0 || index >= positionInfo.starredCount) {
|
||||
// This position is in the unstarred collection
|
||||
const unstarredOffset = index - positionInfo.starredCount;
|
||||
const rangeOffset = Math.floor(unstarredOffset / 50) * 50;
|
||||
return {
|
||||
collection: 'unstarred' as const,
|
||||
offset: rangeOffset,
|
||||
limit: 50,
|
||||
imageIndex: unstarredOffset % 50,
|
||||
};
|
||||
} else {
|
||||
// This position is in the starred collection
|
||||
const rangeOffset = Math.floor(index / 50) * 50;
|
||||
return {
|
||||
collection: 'starred' as const,
|
||||
offset: rangeOffset,
|
||||
limit: 50,
|
||||
imageIndex: index % 50,
|
||||
};
|
||||
}
|
||||
|
||||
return { images: requestedImages, rangesToLoad };
|
||||
},
|
||||
[positionInfo, loadedRanges, getRangeKey]
|
||||
[positionInfo]
|
||||
);
|
||||
|
||||
return {
|
||||
positionInfo,
|
||||
countsLoading,
|
||||
getImagesForRange,
|
||||
setLoadedRanges,
|
||||
loadedRanges,
|
||||
searchParamsChanged,
|
||||
getQueryParamsForPosition,
|
||||
queryArgs,
|
||||
};
|
||||
};
|
||||
|
||||
// Component to handle loading image ranges
|
||||
const ImageRangeLoader = memo(
|
||||
({
|
||||
collection,
|
||||
offset,
|
||||
limit,
|
||||
onDataLoaded,
|
||||
}: {
|
||||
collection: 'starred' | 'unstarred';
|
||||
offset: number;
|
||||
limit: number;
|
||||
onDataLoaded: (key: string, images: ImageDTO[]) => void;
|
||||
}) => {
|
||||
const queryArgs = useAppSelector(selectImageCollectionQueryArgs);
|
||||
// Hook to get image data for a specific position using RTK Query cache
|
||||
const useImageAtPosition = (index: number) => {
|
||||
const { getQueryParamsForPosition, queryArgs } = useVirtualImageData();
|
||||
|
||||
const { data } = useGetImageCollectionQuery({
|
||||
collection,
|
||||
offset,
|
||||
limit,
|
||||
...queryArgs,
|
||||
});
|
||||
const queryParams = getQueryParamsForPosition(index);
|
||||
|
||||
// Update cache when data is loaded - use useEffect to avoid state update during render
|
||||
React.useEffect(() => {
|
||||
if (data?.items) {
|
||||
const key = `${collection}-${offset}-${limit}`;
|
||||
onDataLoaded(key, data.items);
|
||||
}
|
||||
}, [data, collection, offset, limit, onDataLoaded]);
|
||||
const { data } = useGetImageCollectionQuery(
|
||||
queryParams
|
||||
? {
|
||||
collection: queryParams.collection,
|
||||
offset: queryParams.offset,
|
||||
limit: queryParams.limit,
|
||||
...queryArgs,
|
||||
}
|
||||
: skipToken
|
||||
);
|
||||
|
||||
if (!queryParams || !data?.items) {
|
||||
return null;
|
||||
}
|
||||
);
|
||||
|
||||
ImageRangeLoader.displayName = 'ImageRangeLoader';
|
||||
return data.items[queryParams.imageIndex] || null;
|
||||
};
|
||||
|
||||
// Component to render a single image at a position
|
||||
const ImageAtPosition = memo(({ index }: { index: number }) => {
|
||||
const image = useImageAtPosition(index);
|
||||
|
||||
if (image) {
|
||||
return <ImagePlaceholder image={image} />;
|
||||
}
|
||||
|
||||
return <ImageSkeleton />;
|
||||
});
|
||||
|
||||
ImageAtPosition.displayName = 'ImageAtPosition';
|
||||
|
||||
export const NewGallery = memo(() => {
|
||||
const { positionInfo, countsLoading, getImagesForRange, setLoadedRanges, searchParamsChanged } =
|
||||
useVirtualImageData();
|
||||
const [activeRangeLoaders, setActiveRangeLoaders] = useState<Set<string>>(new Set());
|
||||
|
||||
// Force initial range loading when position info becomes available
|
||||
const [hasInitiallyLoaded, setHasInitiallyLoaded] = useState(false);
|
||||
|
||||
// Reset hasInitiallyLoaded when search parameters change
|
||||
React.useEffect(() => {
|
||||
setHasInitiallyLoaded(false);
|
||||
setActiveRangeLoaders(new Set());
|
||||
}, [searchParamsChanged]);
|
||||
|
||||
// Use useEffect for initial load to avoid state updates during render
|
||||
React.useEffect(() => {
|
||||
if (positionInfo && !hasInitiallyLoaded) {
|
||||
// Force initial load of first 100 positions to ensure we see both starred and unstarred
|
||||
const initialResult = getImagesForRange(0, Math.min(99, positionInfo.totalCount - 1));
|
||||
if (!Array.isArray(initialResult)) {
|
||||
const { rangesToLoad } = initialResult;
|
||||
rangesToLoad.forEach((rangeInfo) => {
|
||||
const key = `${rangeInfo.collection}-${rangeInfo.offset}-${rangeInfo.limit}`;
|
||||
if (!activeRangeLoaders.has(key)) {
|
||||
setActiveRangeLoaders((prev) => new Set(prev).add(key));
|
||||
}
|
||||
});
|
||||
}
|
||||
setHasInitiallyLoaded(true);
|
||||
}
|
||||
}, [positionInfo, hasInitiallyLoaded, getImagesForRange, activeRangeLoaders]);
|
||||
|
||||
// Handle range changes from virtuoso
|
||||
const handleRangeChanged = useCallback(
|
||||
(range: { startIndex: number; endIndex: number }) => {
|
||||
if (!positionInfo) {
|
||||
return;
|
||||
}
|
||||
|
||||
const result = getImagesForRange(range.startIndex, range.endIndex);
|
||||
if (!Array.isArray(result)) {
|
||||
const { rangesToLoad } = result;
|
||||
|
||||
// Start loading any missing ranges
|
||||
rangesToLoad.forEach((rangeInfo) => {
|
||||
const key = `${rangeInfo.collection}-${rangeInfo.offset}-${rangeInfo.limit}`;
|
||||
if (!activeRangeLoaders.has(key)) {
|
||||
setActiveRangeLoaders((prev) => new Set(prev).add(key));
|
||||
}
|
||||
});
|
||||
}
|
||||
},
|
||||
[positionInfo, getImagesForRange, activeRangeLoaders]
|
||||
);
|
||||
|
||||
// Handle when range data is loaded
|
||||
const handleDataLoaded = useCallback(
|
||||
(key: string, images: ImageDTO[]) => {
|
||||
setLoadedRanges((prev) => new Map(prev).set(key, images));
|
||||
setActiveRangeLoaders((prev) => {
|
||||
const next = new Set(prev);
|
||||
next.delete(key);
|
||||
return next;
|
||||
});
|
||||
},
|
||||
[setLoadedRanges]
|
||||
);
|
||||
|
||||
const computeItemKey = useCallback(
|
||||
(index: number) => {
|
||||
const result = getImagesForRange(index, index);
|
||||
if (Array.isArray(result)) {
|
||||
return `loading-${index}`;
|
||||
}
|
||||
const { images } = result;
|
||||
const image = images[0];
|
||||
return image ? `image-${index}-${image.image_name}` : `skeleton-${index}`;
|
||||
},
|
||||
[getImagesForRange]
|
||||
);
|
||||
const { positionInfo, countsLoading } = useVirtualImageData();
|
||||
|
||||
// Render item at specific index
|
||||
const itemContent = useCallback(
|
||||
(index: number) => {
|
||||
if (!positionInfo) {
|
||||
return <ImageSkeleton />;
|
||||
}
|
||||
const itemContent = useCallback((index: number) => {
|
||||
return <ImageAtPosition index={index} />;
|
||||
}, []);
|
||||
|
||||
const result = getImagesForRange(index, index);
|
||||
if (Array.isArray(result)) {
|
||||
return <ImageSkeleton />;
|
||||
}
|
||||
|
||||
const { images } = result;
|
||||
const image = images[0];
|
||||
|
||||
if (image) {
|
||||
return <ImagePlaceholder image={image} />;
|
||||
}
|
||||
|
||||
return <ImageSkeleton />;
|
||||
},
|
||||
[positionInfo, getImagesForRange]
|
||||
);
|
||||
// Compute item key using position index - let RTK Query handle the caching
|
||||
const computeItemKey = useCallback((index: number) => `position-${index}`, []);
|
||||
|
||||
if (countsLoading) {
|
||||
return (
|
||||
@@ -296,25 +146,10 @@ export const NewGallery = memo(() => {
|
||||
|
||||
return (
|
||||
<Box height="100%" width="100%">
|
||||
{/* Render active range loaders */}
|
||||
{Array.from(activeRangeLoaders).map((key) => {
|
||||
const [collection, offset, limit] = key.split('-');
|
||||
return (
|
||||
<ImageRangeLoader
|
||||
key={key}
|
||||
collection={collection as 'starred' | 'unstarred'}
|
||||
offset={parseInt(offset ?? '0', 10)}
|
||||
limit={parseInt(limit ?? '50', 10)}
|
||||
onDataLoaded={handleDataLoaded}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Virtualized grid */}
|
||||
<VirtuosoGrid
|
||||
totalCount={positionInfo.totalCount}
|
||||
overscan={200}
|
||||
rangeChanged={handleRangeChanged}
|
||||
itemContent={itemContent}
|
||||
style={style}
|
||||
computeItemKey={computeItemKey}
|
||||
|
||||
Reference in New Issue
Block a user