refactor: gallery scroll (improved impl)

This commit is contained in:
psychedelicious
2025-06-24 21:31:37 +10:00
parent c8254710e6
commit 8327d86774

View File

@@ -1,15 +1,17 @@
import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectGalleryImageMinimumWidth,
selectImageCollectionQueryArgs,
} from 'features/gallery/store/gallerySelectors';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import { memo, useEffect, useMemo, useRef, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type {
GridComponents,
GridComputeItemKey,
GridItemContent,
ListRange,
ScrollSeekConfiguration,
VirtuosoGridHandle,
} from 'react-virtuoso';
@@ -18,12 +20,16 @@ import {
useGetImageCollectionCountsQuery,
useGetImageCollectionQuery,
useGetImageNamesQuery,
useLazyGetImageCollectionQuery,
} from 'services/api/endpoints/images';
import type { ImageCategory, SQLiteDirection } from 'services/api/types';
import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types';
import { objectEntries } from 'tsafe';
import { useDebounce } from 'use-debounce';
import { GalleryImage } from './ImageGrid/GalleryImage';
const log = logger('gallery');
// Type for image collection query arguments
type ImageCollectionQueryArgs = {
board_id?: string;
@@ -33,18 +39,21 @@ type ImageCollectionQueryArgs = {
is_intermediate: boolean;
};
// Types
type Collection = 'starred' | 'unstarred';
interface PositionInfo {
collection: Collection;
offset: number;
itemIndex: number;
}
// Constants
const RANGE_SIZE = 50;
type GridContext = {
queryArgs: ImageCollectionQueryArgs;
imageNames: string[];
starredCount: number;
};
type PositionInfo = {
collection: 'starred' | 'unstarred';
offset: number;
itemIndex: number;
};
// Helper to calculate which collection and range an index belongs to
const getPositionInfo = (index: number, starredCount: number): PositionInfo => {
if (index < starredCount) {
@@ -67,68 +76,63 @@ const getPositionInfo = (index: number, starredCount: number): PositionInfo => {
}
};
// Hook to get image at a specific position
const useImageAtPosition = (index: number, starredCount: number, queryArgs: ImageCollectionQueryArgs) => {
const positionInfo = useMemo(() => getPositionInfo(index, starredCount), [index, starredCount]);
// Hook to get image DTO from batched collection data
const useImageFromBatch = (
imageName: string,
index: number,
starredCount: number,
queryArgs: ImageCollectionQueryArgs
): ImageDTO | null => {
const { arg, options } = useMemo(() => {
const positionInfo = getPositionInfo(index, starredCount);
const arg = useMemo(
() =>
({
collection: positionInfo.collection,
offset: positionInfo.offset,
limit: RANGE_SIZE,
...queryArgs,
}) satisfies Parameters<typeof useGetImageCollectionQuery>[0],
[positionInfo.collection, positionInfo.offset, queryArgs]
);
const arg = {
collection: positionInfo.collection,
offset: positionInfo.offset,
limit: RANGE_SIZE,
...queryArgs,
} satisfies Parameters<typeof useGetImageCollectionQuery>[0];
const options = useMemo(
() =>
({
selectFromResult: ({ data }) => {
if (!data) {
return { imageDTO: null };
} else {
return {
imageDTO: data.items[positionInfo.itemIndex] || null,
};
}
},
}) satisfies Parameters<typeof useGetImageCollectionQuery>[1],
[positionInfo.itemIndex]
);
const options = {
selectFromResult: ({ data }) => {
const imageDTO = data?.items?.[positionInfo.itemIndex] || null;
if (imageDTO && imageDTO.image_name !== imageName) {
log.warnOnce(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`);
}
return { imageDTO };
},
} satisfies Parameters<typeof useGetImageCollectionQuery>[1];
return { arg, options };
}, [imageName, index, queryArgs, starredCount]);
const { imageDTO } = useGetImageCollectionQuery(arg, options);
return imageDTO;
};
type ImageAtPositionProps = {
index: number;
starredCount: number;
queryArgs: ImageCollectionQueryArgs;
};
// Individual image component that gets its data from batched requests
const ImageAtPosition = memo(
({
imageName,
index,
starredCount,
queryArgs,
}: {
imageName: string;
index: number;
starredCount: number;
queryArgs: ImageCollectionQueryArgs;
}) => {
const imageDTO = useImageFromBatch(imageName, index, starredCount, queryArgs);
type GridContext = {
queryArgs: ImageCollectionQueryArgs;
counts: {
starred_count: number;
unstarred_count: number;
total_count: number;
};
};
if (!imageDTO) {
return <Skeleton w="full" h="full" />;
}
// Individual image component
const ImageAtPosition = memo(({ index, starredCount, queryArgs }: ImageAtPositionProps) => {
const imageDTO = useImageAtPosition(index, starredCount, queryArgs);
if (!imageDTO) {
return <Skeleton w="full" h="full" />;
return <GalleryImage imageDTO={imageDTO} />;
}
return <GalleryImage imageDTO={imageDTO} />;
});
);
ImageAtPosition.displayName = 'ImageAtPosition';
export const useDebouncedImageCollectionQueryArgs = () => {
@@ -137,31 +141,52 @@ export const useDebouncedImageCollectionQueryArgs = () => {
return queryArgs;
};
const getImageCollectionCountsOptions = {
selectFromResult: ({ data, isLoading }) => ({
counts: data
? {
starred_count: data.starred_count,
unstarred_count: data.unstarred_count,
total_count: data.starred_count + data.unstarred_count,
}
: {
starred_count: 0,
unstarred_count: 0,
total_count: 0,
},
isLoading,
}),
} satisfies Parameters<typeof useGetImageCollectionCountsQuery>[1];
// Memoized item content function
const itemContent: GridItemContent<null, GridContext> = (index, _item, { queryArgs, counts }) => {
return <ImageAtPosition index={index} starredCount={counts.starred_count} queryArgs={queryArgs} />;
// Memoized item content function that uses image names as data but batches requests
const itemContent: GridItemContent<string, GridContext> = (index, imageName, { queryArgs, starredCount }) => {
if (!imageName) {
return <Skeleton w="full" h="full" />;
}
return <ImageAtPosition imageName={imageName} index={index} starredCount={starredCount} queryArgs={queryArgs} />;
};
// Memoized compute key function
const computeItemKey: GridComputeItemKey<null, GridContext> = (index, _item, { queryArgs }) => {
return `${JSON.stringify(queryArgs)}-${index}`;
// Memoized compute key function using image names
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
return `${JSON.stringify(queryArgs)}-${imageName || index}`;
};
// Hook to prefetch ranges based on visible area
const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQueryArgs) => {
const [triggerGetImageCollection] = useLazyGetImageCollectionQuery();
const prefetchRange = useCallback(
(startIndex: number, endIndex: number) => {
const ranges = {
starred: new Set<number>(),
unstarred: new Set<number>(),
};
// Collect all unique ranges needed for the visible area
for (let i = startIndex; i <= endIndex; i++) {
const positionInfo = getPositionInfo(i, starredCount);
ranges[positionInfo.collection].add(positionInfo.offset);
}
// Trigger queries for each unique range
for (const [collection, offsets] of objectEntries(ranges)) {
for (const offset of offsets) {
triggerGetImageCollection({
collection,
offset,
limit: RANGE_SIZE,
...queryArgs,
});
}
}
},
[starredCount, queryArgs, triggerGetImageCollection]
);
return prefetchRange;
};
// Main gallery component
@@ -169,18 +194,21 @@ export const NewGallery = memo(() => {
const queryArgs = useDebouncedImageCollectionQueryArgs();
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const { counts, isLoading } = useGetImageCollectionCountsQuery(queryArgs, getImageCollectionCountsOptions);
// Get the ordered list of image names - this is our primary data source
const { data: imageNames = [], isLoading } = useGetImageNamesQuery(queryArgs);
// Load image names for selection operations - this is lightweight and ensures
// selection operations work even before image data is fully loaded
useGetImageNamesQuery(queryArgs);
// Get starred count for position calculations
const { data: counts } = useGetImageCollectionCountsQuery(queryArgs);
const starredCount = counts?.starred_count ?? 0;
const prefetchRange = usePrefetchRanges(starredCount, queryArgs);
// Reset scroll position when query parameters change
useEffect(() => {
if (virtuosoRef.current && counts.total_count > 0) {
if (virtuosoRef.current && imageNames.length > 0) {
virtuosoRef.current.scrollToIndex({ index: 0, behavior: 'auto' });
}
}, [counts.total_count, queryArgs]);
}, [queryArgs, imageNames.length]);
const rootRef = useRef<HTMLDivElement>(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null);
@@ -213,13 +241,22 @@ export const NewGallery = memo(() => {
};
}, [scroller, initialize, osInstance]);
// Handle range changes to prefetch data for visible + buffer areas
const handleRangeChanged = useCallback(
(range: ListRange) => {
prefetchRange(range.startIndex, range.endIndex);
},
[prefetchRange]
);
const context = useMemo(
() =>
({
counts,
imageNames,
queryArgs,
starredCount,
}) satisfies GridContext,
[counts, queryArgs]
[imageNames, queryArgs, starredCount]
);
if (isLoading) {
@@ -231,7 +268,7 @@ export const NewGallery = memo(() => {
);
}
if (counts.total_count === 0) {
if (imageNames.length === 0) {
return (
<Flex height="100%" alignItems="center" justifyContent="center">
<Text color="base.300">No images found</Text>
@@ -241,17 +278,19 @@ export const NewGallery = memo(() => {
return (
<Box data-overlayscrollbars-initialize="" ref={rootRef} w="full" h="full">
<VirtuosoGrid<null, GridContext>
<VirtuosoGrid<string, GridContext>
ref={virtuosoRef}
context={context}
totalCount={counts.total_count}
increaseViewportBy={1024}
totalCount={imageNames.length}
data={imageNames}
increaseViewportBy={2048}
itemContent={itemContent}
computeItemKey={computeItemKey}
components={components}
style={style}
scrollerRef={setScroller}
scrollSeekConfiguration={scrollSeekConfiguration}
rangeChanged={handleRangeChanged}
/>
</Box>
);
@@ -260,7 +299,7 @@ export const NewGallery = memo(() => {
NewGallery.displayName = 'NewGallery';
const scrollSeekConfiguration: ScrollSeekConfiguration = {
enter: (velocity) => velocity > 1000,
enter: (velocity) => velocity > 2048,
exit: (velocity) => velocity === 0,
};