mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 22:35:27 -05:00
refactor: gallery scroll (improved impl)
This commit is contained in:
@@ -1,15 +1,17 @@
|
||||
import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectGalleryImageMinimumWidth,
|
||||
selectImageCollectionQueryArgs,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import { memo, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import type {
|
||||
GridComponents,
|
||||
GridComputeItemKey,
|
||||
GridItemContent,
|
||||
ListRange,
|
||||
ScrollSeekConfiguration,
|
||||
VirtuosoGridHandle,
|
||||
} from 'react-virtuoso';
|
||||
@@ -18,12 +20,16 @@ import {
|
||||
useGetImageCollectionCountsQuery,
|
||||
useGetImageCollectionQuery,
|
||||
useGetImageNamesQuery,
|
||||
useLazyGetImageCollectionQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import type { ImageCategory, SQLiteDirection } from 'services/api/types';
|
||||
import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
import { GalleryImage } from './ImageGrid/GalleryImage';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
// Type for image collection query arguments
|
||||
type ImageCollectionQueryArgs = {
|
||||
board_id?: string;
|
||||
@@ -33,18 +39,21 @@ type ImageCollectionQueryArgs = {
|
||||
is_intermediate: boolean;
|
||||
};
|
||||
|
||||
// Types
|
||||
type Collection = 'starred' | 'unstarred';
|
||||
|
||||
interface PositionInfo {
|
||||
collection: Collection;
|
||||
offset: number;
|
||||
itemIndex: number;
|
||||
}
|
||||
|
||||
// Constants
|
||||
const RANGE_SIZE = 50;
|
||||
|
||||
type GridContext = {
|
||||
queryArgs: ImageCollectionQueryArgs;
|
||||
imageNames: string[];
|
||||
starredCount: number;
|
||||
};
|
||||
|
||||
type PositionInfo = {
|
||||
collection: 'starred' | 'unstarred';
|
||||
offset: number;
|
||||
itemIndex: number;
|
||||
};
|
||||
|
||||
// Helper to calculate which collection and range an index belongs to
|
||||
const getPositionInfo = (index: number, starredCount: number): PositionInfo => {
|
||||
if (index < starredCount) {
|
||||
@@ -67,68 +76,63 @@ const getPositionInfo = (index: number, starredCount: number): PositionInfo => {
|
||||
}
|
||||
};
|
||||
|
||||
// Hook to get image at a specific position
|
||||
const useImageAtPosition = (index: number, starredCount: number, queryArgs: ImageCollectionQueryArgs) => {
|
||||
const positionInfo = useMemo(() => getPositionInfo(index, starredCount), [index, starredCount]);
|
||||
// Hook to get image DTO from batched collection data
|
||||
const useImageFromBatch = (
|
||||
imageName: string,
|
||||
index: number,
|
||||
starredCount: number,
|
||||
queryArgs: ImageCollectionQueryArgs
|
||||
): ImageDTO | null => {
|
||||
const { arg, options } = useMemo(() => {
|
||||
const positionInfo = getPositionInfo(index, starredCount);
|
||||
|
||||
const arg = useMemo(
|
||||
() =>
|
||||
({
|
||||
collection: positionInfo.collection,
|
||||
offset: positionInfo.offset,
|
||||
limit: RANGE_SIZE,
|
||||
...queryArgs,
|
||||
}) satisfies Parameters<typeof useGetImageCollectionQuery>[0],
|
||||
[positionInfo.collection, positionInfo.offset, queryArgs]
|
||||
);
|
||||
const arg = {
|
||||
collection: positionInfo.collection,
|
||||
offset: positionInfo.offset,
|
||||
limit: RANGE_SIZE,
|
||||
...queryArgs,
|
||||
} satisfies Parameters<typeof useGetImageCollectionQuery>[0];
|
||||
|
||||
const options = useMemo(
|
||||
() =>
|
||||
({
|
||||
selectFromResult: ({ data }) => {
|
||||
if (!data) {
|
||||
return { imageDTO: null };
|
||||
} else {
|
||||
return {
|
||||
imageDTO: data.items[positionInfo.itemIndex] || null,
|
||||
};
|
||||
}
|
||||
},
|
||||
}) satisfies Parameters<typeof useGetImageCollectionQuery>[1],
|
||||
[positionInfo.itemIndex]
|
||||
);
|
||||
const options = {
|
||||
selectFromResult: ({ data }) => {
|
||||
const imageDTO = data?.items?.[positionInfo.itemIndex] || null;
|
||||
if (imageDTO && imageDTO.image_name !== imageName) {
|
||||
log.warnOnce(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`);
|
||||
}
|
||||
return { imageDTO };
|
||||
},
|
||||
} satisfies Parameters<typeof useGetImageCollectionQuery>[1];
|
||||
|
||||
return { arg, options };
|
||||
}, [imageName, index, queryArgs, starredCount]);
|
||||
|
||||
const { imageDTO } = useGetImageCollectionQuery(arg, options);
|
||||
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
type ImageAtPositionProps = {
|
||||
index: number;
|
||||
starredCount: number;
|
||||
queryArgs: ImageCollectionQueryArgs;
|
||||
};
|
||||
// Individual image component that gets its data from batched requests
|
||||
const ImageAtPosition = memo(
|
||||
({
|
||||
imageName,
|
||||
index,
|
||||
starredCount,
|
||||
queryArgs,
|
||||
}: {
|
||||
imageName: string;
|
||||
index: number;
|
||||
starredCount: number;
|
||||
queryArgs: ImageCollectionQueryArgs;
|
||||
}) => {
|
||||
const imageDTO = useImageFromBatch(imageName, index, starredCount, queryArgs);
|
||||
|
||||
type GridContext = {
|
||||
queryArgs: ImageCollectionQueryArgs;
|
||||
counts: {
|
||||
starred_count: number;
|
||||
unstarred_count: number;
|
||||
total_count: number;
|
||||
};
|
||||
};
|
||||
if (!imageDTO) {
|
||||
return <Skeleton w="full" h="full" />;
|
||||
}
|
||||
|
||||
// Individual image component
|
||||
const ImageAtPosition = memo(({ index, starredCount, queryArgs }: ImageAtPositionProps) => {
|
||||
const imageDTO = useImageAtPosition(index, starredCount, queryArgs);
|
||||
|
||||
if (!imageDTO) {
|
||||
return <Skeleton w="full" h="full" />;
|
||||
return <GalleryImage imageDTO={imageDTO} />;
|
||||
}
|
||||
|
||||
return <GalleryImage imageDTO={imageDTO} />;
|
||||
});
|
||||
|
||||
);
|
||||
ImageAtPosition.displayName = 'ImageAtPosition';
|
||||
|
||||
export const useDebouncedImageCollectionQueryArgs = () => {
|
||||
@@ -137,31 +141,52 @@ export const useDebouncedImageCollectionQueryArgs = () => {
|
||||
return queryArgs;
|
||||
};
|
||||
|
||||
const getImageCollectionCountsOptions = {
|
||||
selectFromResult: ({ data, isLoading }) => ({
|
||||
counts: data
|
||||
? {
|
||||
starred_count: data.starred_count,
|
||||
unstarred_count: data.unstarred_count,
|
||||
total_count: data.starred_count + data.unstarred_count,
|
||||
}
|
||||
: {
|
||||
starred_count: 0,
|
||||
unstarred_count: 0,
|
||||
total_count: 0,
|
||||
},
|
||||
isLoading,
|
||||
}),
|
||||
} satisfies Parameters<typeof useGetImageCollectionCountsQuery>[1];
|
||||
|
||||
// Memoized item content function
|
||||
const itemContent: GridItemContent<null, GridContext> = (index, _item, { queryArgs, counts }) => {
|
||||
return <ImageAtPosition index={index} starredCount={counts.starred_count} queryArgs={queryArgs} />;
|
||||
// Memoized item content function that uses image names as data but batches requests
|
||||
const itemContent: GridItemContent<string, GridContext> = (index, imageName, { queryArgs, starredCount }) => {
|
||||
if (!imageName) {
|
||||
return <Skeleton w="full" h="full" />;
|
||||
}
|
||||
return <ImageAtPosition imageName={imageName} index={index} starredCount={starredCount} queryArgs={queryArgs} />;
|
||||
};
|
||||
|
||||
// Memoized compute key function
|
||||
const computeItemKey: GridComputeItemKey<null, GridContext> = (index, _item, { queryArgs }) => {
|
||||
return `${JSON.stringify(queryArgs)}-${index}`;
|
||||
// Memoized compute key function using image names
|
||||
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
|
||||
return `${JSON.stringify(queryArgs)}-${imageName || index}`;
|
||||
};
|
||||
|
||||
// Hook to prefetch ranges based on visible area
|
||||
const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQueryArgs) => {
|
||||
const [triggerGetImageCollection] = useLazyGetImageCollectionQuery();
|
||||
|
||||
const prefetchRange = useCallback(
|
||||
(startIndex: number, endIndex: number) => {
|
||||
const ranges = {
|
||||
starred: new Set<number>(),
|
||||
unstarred: new Set<number>(),
|
||||
};
|
||||
|
||||
// Collect all unique ranges needed for the visible area
|
||||
for (let i = startIndex; i <= endIndex; i++) {
|
||||
const positionInfo = getPositionInfo(i, starredCount);
|
||||
ranges[positionInfo.collection].add(positionInfo.offset);
|
||||
}
|
||||
|
||||
// Trigger queries for each unique range
|
||||
for (const [collection, offsets] of objectEntries(ranges)) {
|
||||
for (const offset of offsets) {
|
||||
triggerGetImageCollection({
|
||||
collection,
|
||||
offset,
|
||||
limit: RANGE_SIZE,
|
||||
...queryArgs,
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
[starredCount, queryArgs, triggerGetImageCollection]
|
||||
);
|
||||
|
||||
return prefetchRange;
|
||||
};
|
||||
|
||||
// Main gallery component
|
||||
@@ -169,18 +194,21 @@ export const NewGallery = memo(() => {
|
||||
const queryArgs = useDebouncedImageCollectionQueryArgs();
|
||||
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
|
||||
|
||||
const { counts, isLoading } = useGetImageCollectionCountsQuery(queryArgs, getImageCollectionCountsOptions);
|
||||
// Get the ordered list of image names - this is our primary data source
|
||||
const { data: imageNames = [], isLoading } = useGetImageNamesQuery(queryArgs);
|
||||
|
||||
// Load image names for selection operations - this is lightweight and ensures
|
||||
// selection operations work even before image data is fully loaded
|
||||
useGetImageNamesQuery(queryArgs);
|
||||
// Get starred count for position calculations
|
||||
const { data: counts } = useGetImageCollectionCountsQuery(queryArgs);
|
||||
const starredCount = counts?.starred_count ?? 0;
|
||||
|
||||
const prefetchRange = usePrefetchRanges(starredCount, queryArgs);
|
||||
|
||||
// Reset scroll position when query parameters change
|
||||
useEffect(() => {
|
||||
if (virtuosoRef.current && counts.total_count > 0) {
|
||||
if (virtuosoRef.current && imageNames.length > 0) {
|
||||
virtuosoRef.current.scrollToIndex({ index: 0, behavior: 'auto' });
|
||||
}
|
||||
}, [counts.total_count, queryArgs]);
|
||||
}, [queryArgs, imageNames.length]);
|
||||
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
||||
@@ -213,13 +241,22 @@ export const NewGallery = memo(() => {
|
||||
};
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
// Handle range changes to prefetch data for visible + buffer areas
|
||||
const handleRangeChanged = useCallback(
|
||||
(range: ListRange) => {
|
||||
prefetchRange(range.startIndex, range.endIndex);
|
||||
},
|
||||
[prefetchRange]
|
||||
);
|
||||
|
||||
const context = useMemo(
|
||||
() =>
|
||||
({
|
||||
counts,
|
||||
imageNames,
|
||||
queryArgs,
|
||||
starredCount,
|
||||
}) satisfies GridContext,
|
||||
[counts, queryArgs]
|
||||
[imageNames, queryArgs, starredCount]
|
||||
);
|
||||
|
||||
if (isLoading) {
|
||||
@@ -231,7 +268,7 @@ export const NewGallery = memo(() => {
|
||||
);
|
||||
}
|
||||
|
||||
if (counts.total_count === 0) {
|
||||
if (imageNames.length === 0) {
|
||||
return (
|
||||
<Flex height="100%" alignItems="center" justifyContent="center">
|
||||
<Text color="base.300">No images found</Text>
|
||||
@@ -241,17 +278,19 @@ export const NewGallery = memo(() => {
|
||||
|
||||
return (
|
||||
<Box data-overlayscrollbars-initialize="" ref={rootRef} w="full" h="full">
|
||||
<VirtuosoGrid<null, GridContext>
|
||||
<VirtuosoGrid<string, GridContext>
|
||||
ref={virtuosoRef}
|
||||
context={context}
|
||||
totalCount={counts.total_count}
|
||||
increaseViewportBy={1024}
|
||||
totalCount={imageNames.length}
|
||||
data={imageNames}
|
||||
increaseViewportBy={2048}
|
||||
itemContent={itemContent}
|
||||
computeItemKey={computeItemKey}
|
||||
components={components}
|
||||
style={style}
|
||||
scrollerRef={setScroller}
|
||||
scrollSeekConfiguration={scrollSeekConfiguration}
|
||||
rangeChanged={handleRangeChanged}
|
||||
/>
|
||||
</Box>
|
||||
);
|
||||
@@ -260,7 +299,7 @@ export const NewGallery = memo(() => {
|
||||
NewGallery.displayName = 'NewGallery';
|
||||
|
||||
const scrollSeekConfiguration: ScrollSeekConfiguration = {
|
||||
enter: (velocity) => velocity > 1000,
|
||||
enter: (velocity) => velocity > 2048,
|
||||
exit: (velocity) => velocity === 0,
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user