refactor: gallery scroll (improved impl)

This commit is contained in:
psychedelicious
2025-06-24 21:31:37 +10:00
parent c8254710e6
commit 8327d86774

View File

@@ -1,15 +1,17 @@
import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library'; import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks'; import { useAppSelector } from 'app/store/storeHooks';
import { import {
selectGalleryImageMinimumWidth, selectGalleryImageMinimumWidth,
selectImageCollectionQueryArgs, selectImageCollectionQueryArgs,
} from 'features/gallery/store/gallerySelectors'; } from 'features/gallery/store/gallerySelectors';
import { useOverlayScrollbars } from 'overlayscrollbars-react'; import { useOverlayScrollbars } from 'overlayscrollbars-react';
import { memo, useEffect, useMemo, useRef, useState } from 'react'; import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type { import type {
GridComponents, GridComponents,
GridComputeItemKey, GridComputeItemKey,
GridItemContent, GridItemContent,
ListRange,
ScrollSeekConfiguration, ScrollSeekConfiguration,
VirtuosoGridHandle, VirtuosoGridHandle,
} from 'react-virtuoso'; } from 'react-virtuoso';
@@ -18,12 +20,16 @@ import {
useGetImageCollectionCountsQuery, useGetImageCollectionCountsQuery,
useGetImageCollectionQuery, useGetImageCollectionQuery,
useGetImageNamesQuery, useGetImageNamesQuery,
useLazyGetImageCollectionQuery,
} from 'services/api/endpoints/images'; } from 'services/api/endpoints/images';
import type { ImageCategory, SQLiteDirection } from 'services/api/types'; import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types';
import { objectEntries } from 'tsafe';
import { useDebounce } from 'use-debounce'; import { useDebounce } from 'use-debounce';
import { GalleryImage } from './ImageGrid/GalleryImage'; import { GalleryImage } from './ImageGrid/GalleryImage';
const log = logger('gallery');
// Type for image collection query arguments // Type for image collection query arguments
type ImageCollectionQueryArgs = { type ImageCollectionQueryArgs = {
board_id?: string; board_id?: string;
@@ -33,18 +39,21 @@ type ImageCollectionQueryArgs = {
is_intermediate: boolean; is_intermediate: boolean;
}; };
// Types
type Collection = 'starred' | 'unstarred';
interface PositionInfo {
collection: Collection;
offset: number;
itemIndex: number;
}
// Constants // Constants
const RANGE_SIZE = 50; const RANGE_SIZE = 50;
type GridContext = {
queryArgs: ImageCollectionQueryArgs;
imageNames: string[];
starredCount: number;
};
type PositionInfo = {
collection: 'starred' | 'unstarred';
offset: number;
itemIndex: number;
};
// Helper to calculate which collection and range an index belongs to // Helper to calculate which collection and range an index belongs to
const getPositionInfo = (index: number, starredCount: number): PositionInfo => { const getPositionInfo = (index: number, starredCount: number): PositionInfo => {
if (index < starredCount) { if (index < starredCount) {
@@ -67,68 +76,63 @@ const getPositionInfo = (index: number, starredCount: number): PositionInfo => {
} }
}; };
// Hook to get image at a specific position // Hook to get image DTO from batched collection data
const useImageAtPosition = (index: number, starredCount: number, queryArgs: ImageCollectionQueryArgs) => { const useImageFromBatch = (
const positionInfo = useMemo(() => getPositionInfo(index, starredCount), [index, starredCount]); imageName: string,
index: number,
starredCount: number,
queryArgs: ImageCollectionQueryArgs
): ImageDTO | null => {
const { arg, options } = useMemo(() => {
const positionInfo = getPositionInfo(index, starredCount);
const arg = useMemo( const arg = {
() => collection: positionInfo.collection,
({ offset: positionInfo.offset,
collection: positionInfo.collection, limit: RANGE_SIZE,
offset: positionInfo.offset, ...queryArgs,
limit: RANGE_SIZE, } satisfies Parameters<typeof useGetImageCollectionQuery>[0];
...queryArgs,
}) satisfies Parameters<typeof useGetImageCollectionQuery>[0],
[positionInfo.collection, positionInfo.offset, queryArgs]
);
const options = useMemo( const options = {
() => selectFromResult: ({ data }) => {
({ const imageDTO = data?.items?.[positionInfo.itemIndex] || null;
selectFromResult: ({ data }) => { if (imageDTO && imageDTO.image_name !== imageName) {
if (!data) { log.warnOnce(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`);
return { imageDTO: null }; }
} else { return { imageDTO };
return { },
imageDTO: data.items[positionInfo.itemIndex] || null, } satisfies Parameters<typeof useGetImageCollectionQuery>[1];
};
} return { arg, options };
}, }, [imageName, index, queryArgs, starredCount]);
}) satisfies Parameters<typeof useGetImageCollectionQuery>[1],
[positionInfo.itemIndex]
);
const { imageDTO } = useGetImageCollectionQuery(arg, options); const { imageDTO } = useGetImageCollectionQuery(arg, options);
return imageDTO; return imageDTO;
}; };
type ImageAtPositionProps = { // Individual image component that gets its data from batched requests
index: number; const ImageAtPosition = memo(
starredCount: number; ({
queryArgs: ImageCollectionQueryArgs; imageName,
}; index,
starredCount,
queryArgs,
}: {
imageName: string;
index: number;
starredCount: number;
queryArgs: ImageCollectionQueryArgs;
}) => {
const imageDTO = useImageFromBatch(imageName, index, starredCount, queryArgs);
type GridContext = { if (!imageDTO) {
queryArgs: ImageCollectionQueryArgs; return <Skeleton w="full" h="full" />;
counts: { }
starred_count: number;
unstarred_count: number;
total_count: number;
};
};
// Individual image component return <GalleryImage imageDTO={imageDTO} />;
const ImageAtPosition = memo(({ index, starredCount, queryArgs }: ImageAtPositionProps) => {
const imageDTO = useImageAtPosition(index, starredCount, queryArgs);
if (!imageDTO) {
return <Skeleton w="full" h="full" />;
} }
);
return <GalleryImage imageDTO={imageDTO} />;
});
ImageAtPosition.displayName = 'ImageAtPosition'; ImageAtPosition.displayName = 'ImageAtPosition';
export const useDebouncedImageCollectionQueryArgs = () => { export const useDebouncedImageCollectionQueryArgs = () => {
@@ -137,31 +141,52 @@ export const useDebouncedImageCollectionQueryArgs = () => {
return queryArgs; return queryArgs;
}; };
const getImageCollectionCountsOptions = { // Memoized item content function that uses image names as data but batches requests
selectFromResult: ({ data, isLoading }) => ({ const itemContent: GridItemContent<string, GridContext> = (index, imageName, { queryArgs, starredCount }) => {
counts: data if (!imageName) {
? { return <Skeleton w="full" h="full" />;
starred_count: data.starred_count, }
unstarred_count: data.unstarred_count, return <ImageAtPosition imageName={imageName} index={index} starredCount={starredCount} queryArgs={queryArgs} />;
total_count: data.starred_count + data.unstarred_count,
}
: {
starred_count: 0,
unstarred_count: 0,
total_count: 0,
},
isLoading,
}),
} satisfies Parameters<typeof useGetImageCollectionCountsQuery>[1];
// Memoized item content function
const itemContent: GridItemContent<null, GridContext> = (index, _item, { queryArgs, counts }) => {
return <ImageAtPosition index={index} starredCount={counts.starred_count} queryArgs={queryArgs} />;
}; };
// Memoized compute key function // Memoized compute key function using image names
const computeItemKey: GridComputeItemKey<null, GridContext> = (index, _item, { queryArgs }) => { const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
return `${JSON.stringify(queryArgs)}-${index}`; return `${JSON.stringify(queryArgs)}-${imageName || index}`;
};
// Hook to prefetch ranges based on visible area
const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQueryArgs) => {
const [triggerGetImageCollection] = useLazyGetImageCollectionQuery();
const prefetchRange = useCallback(
(startIndex: number, endIndex: number) => {
const ranges = {
starred: new Set<number>(),
unstarred: new Set<number>(),
};
// Collect all unique ranges needed for the visible area
for (let i = startIndex; i <= endIndex; i++) {
const positionInfo = getPositionInfo(i, starredCount);
ranges[positionInfo.collection].add(positionInfo.offset);
}
// Trigger queries for each unique range
for (const [collection, offsets] of objectEntries(ranges)) {
for (const offset of offsets) {
triggerGetImageCollection({
collection,
offset,
limit: RANGE_SIZE,
...queryArgs,
});
}
}
},
[starredCount, queryArgs, triggerGetImageCollection]
);
return prefetchRange;
}; };
// Main gallery component // Main gallery component
@@ -169,18 +194,21 @@ export const NewGallery = memo(() => {
const queryArgs = useDebouncedImageCollectionQueryArgs(); const queryArgs = useDebouncedImageCollectionQueryArgs();
const virtuosoRef = useRef<VirtuosoGridHandle>(null); const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const { counts, isLoading } = useGetImageCollectionCountsQuery(queryArgs, getImageCollectionCountsOptions); // Get the ordered list of image names - this is our primary data source
const { data: imageNames = [], isLoading } = useGetImageNamesQuery(queryArgs);
// Load image names for selection operations - this is lightweight and ensures // Get starred count for position calculations
// selection operations work even before image data is fully loaded const { data: counts } = useGetImageCollectionCountsQuery(queryArgs);
useGetImageNamesQuery(queryArgs); const starredCount = counts?.starred_count ?? 0;
const prefetchRange = usePrefetchRanges(starredCount, queryArgs);
// Reset scroll position when query parameters change // Reset scroll position when query parameters change
useEffect(() => { useEffect(() => {
if (virtuosoRef.current && counts.total_count > 0) { if (virtuosoRef.current && imageNames.length > 0) {
virtuosoRef.current.scrollToIndex({ index: 0, behavior: 'auto' }); virtuosoRef.current.scrollToIndex({ index: 0, behavior: 'auto' });
} }
}, [counts.total_count, queryArgs]); }, [queryArgs, imageNames.length]);
const rootRef = useRef<HTMLDivElement>(null); const rootRef = useRef<HTMLDivElement>(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null); const [scroller, setScroller] = useState<HTMLElement | null>(null);
@@ -213,13 +241,22 @@ export const NewGallery = memo(() => {
}; };
}, [scroller, initialize, osInstance]); }, [scroller, initialize, osInstance]);
// Handle range changes to prefetch data for visible + buffer areas
const handleRangeChanged = useCallback(
(range: ListRange) => {
prefetchRange(range.startIndex, range.endIndex);
},
[prefetchRange]
);
const context = useMemo( const context = useMemo(
() => () =>
({ ({
counts, imageNames,
queryArgs, queryArgs,
starredCount,
}) satisfies GridContext, }) satisfies GridContext,
[counts, queryArgs] [imageNames, queryArgs, starredCount]
); );
if (isLoading) { if (isLoading) {
@@ -231,7 +268,7 @@ export const NewGallery = memo(() => {
); );
} }
if (counts.total_count === 0) { if (imageNames.length === 0) {
return ( return (
<Flex height="100%" alignItems="center" justifyContent="center"> <Flex height="100%" alignItems="center" justifyContent="center">
<Text color="base.300">No images found</Text> <Text color="base.300">No images found</Text>
@@ -241,17 +278,19 @@ export const NewGallery = memo(() => {
return ( return (
<Box data-overlayscrollbars-initialize="" ref={rootRef} w="full" h="full"> <Box data-overlayscrollbars-initialize="" ref={rootRef} w="full" h="full">
<VirtuosoGrid<null, GridContext> <VirtuosoGrid<string, GridContext>
ref={virtuosoRef} ref={virtuosoRef}
context={context} context={context}
totalCount={counts.total_count} totalCount={imageNames.length}
increaseViewportBy={1024} data={imageNames}
increaseViewportBy={2048}
itemContent={itemContent} itemContent={itemContent}
computeItemKey={computeItemKey} computeItemKey={computeItemKey}
components={components} components={components}
style={style} style={style}
scrollerRef={setScroller} scrollerRef={setScroller}
scrollSeekConfiguration={scrollSeekConfiguration} scrollSeekConfiguration={scrollSeekConfiguration}
rangeChanged={handleRangeChanged}
/> />
</Box> </Box>
); );
@@ -260,7 +299,7 @@ export const NewGallery = memo(() => {
NewGallery.displayName = 'NewGallery'; NewGallery.displayName = 'NewGallery';
const scrollSeekConfiguration: ScrollSeekConfiguration = { const scrollSeekConfiguration: ScrollSeekConfiguration = {
enter: (velocity) => velocity > 1000, enter: (velocity) => velocity > 2048,
exit: (velocity) => velocity === 0, exit: (velocity) => velocity === 0,
}; };