mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor: gallery scroll (improved impl)
This commit is contained in:
@@ -170,8 +170,6 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
reducer: rememberedRootReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
// serializableCheck: false,
|
||||
// immutableCheck: false,
|
||||
serializableCheck: import.meta.env.MODE === 'development',
|
||||
immutableCheck: import.meta.env.MODE === 'development',
|
||||
})
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import type { MutableRefObject } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import type {
|
||||
GridComponents,
|
||||
@@ -18,120 +19,64 @@ import type {
|
||||
VirtuosoGridHandle,
|
||||
} from 'react-virtuoso';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import {
|
||||
useGetImageCollectionCountsQuery,
|
||||
useGetImageCollectionQuery,
|
||||
useGetImageNamesQuery,
|
||||
useLazyGetImageCollectionQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types';
|
||||
import { objectEntries } from 'tsafe';
|
||||
import { useGetImageNamesQuery, useListImagesQuery } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, ListImagesArgs } from 'services/api/types';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
import { GalleryImage } from './ImageGrid/GalleryImage';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
// Type for image collection query arguments
|
||||
type ImageCollectionQueryArgs = {
|
||||
board_id?: string;
|
||||
categories?: ImageCategory[];
|
||||
search_term?: string;
|
||||
order_dir?: SQLiteDirection;
|
||||
is_intermediate: boolean;
|
||||
};
|
||||
|
||||
// Constants
|
||||
const RANGE_SIZE = 50;
|
||||
const PAGE_SIZE = 100;
|
||||
const VIEWPORT_BUFFER = 2048;
|
||||
const SCROLL_SEEK_VELOCITY_THRESHOLD = 2048;
|
||||
const SCROLL_SEEK_VELOCITY_THRESHOLD = 4096;
|
||||
const DEBOUNCE_DELAY = 500;
|
||||
const GRID_GAP = 2;
|
||||
const GRID_GAP = 1;
|
||||
const SPINNER_OPACITY = 0.3;
|
||||
|
||||
type GridContext = {
|
||||
queryArgs: ImageCollectionQueryArgs;
|
||||
queryArgs: ListImagesArgs;
|
||||
imageNames: string[];
|
||||
starredCount: number;
|
||||
};
|
||||
|
||||
type PositionInfo = {
|
||||
collection: 'starred' | 'unstarred';
|
||||
offset: number;
|
||||
itemIndex: number;
|
||||
export const useDebouncedImageCollectionQueryArgs = () => {
|
||||
const _galleryQueryArgs = useAppSelector(selectImageCollectionQueryArgs);
|
||||
const [queryArgs] = useDebounce(_galleryQueryArgs, DEBOUNCE_DELAY);
|
||||
return queryArgs;
|
||||
};
|
||||
|
||||
// Helper to calculate which collection and range an index belongs to
|
||||
const getPositionInfo = (index: number, starredCount: number): PositionInfo => {
|
||||
if (index < starredCount) {
|
||||
// Starred collection
|
||||
const offset = Math.floor(index / RANGE_SIZE) * RANGE_SIZE;
|
||||
return {
|
||||
collection: 'starred',
|
||||
offset,
|
||||
itemIndex: index - offset,
|
||||
};
|
||||
} else {
|
||||
// Unstarred collection
|
||||
const unstarredIndex = index - starredCount;
|
||||
const offset = Math.floor(unstarredIndex / RANGE_SIZE) * RANGE_SIZE;
|
||||
return {
|
||||
collection: 'unstarred',
|
||||
offset,
|
||||
itemIndex: unstarredIndex - offset,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
// Hook to get image DTO from batched collection data
|
||||
const useImageFromBatch = (
|
||||
imageName: string,
|
||||
index: number,
|
||||
starredCount: number,
|
||||
queryArgs: ImageCollectionQueryArgs
|
||||
): ImageDTO | null => {
|
||||
// Hook to get an image DTO from cache or trigger loading
|
||||
const useImageDTOFromListQuery = (index: number, imageName: string, queryArgs: ListImagesArgs): ImageDTO | null => {
|
||||
const { arg, options } = useMemo(() => {
|
||||
const positionInfo = getPositionInfo(index, starredCount);
|
||||
const pageOffset = Math.floor(index / PAGE_SIZE) * PAGE_SIZE;
|
||||
return {
|
||||
arg: {
|
||||
...queryArgs,
|
||||
offset: pageOffset,
|
||||
limit: PAGE_SIZE,
|
||||
} satisfies Parameters<typeof useListImagesQuery>[0],
|
||||
options: {
|
||||
selectFromResult: ({ data }) => {
|
||||
const imageDTO = data?.items?.[index - pageOffset] || null;
|
||||
if (imageDTO && imageDTO.image_name !== imageName) {
|
||||
log.warn(`Image at index ${index} does not match expected image name ${imageName}`);
|
||||
}
|
||||
return { imageDTO };
|
||||
},
|
||||
} satisfies Parameters<typeof useListImagesQuery>[1],
|
||||
};
|
||||
}, [index, queryArgs, imageName]);
|
||||
|
||||
const arg = {
|
||||
collection: positionInfo.collection,
|
||||
offset: positionInfo.offset,
|
||||
limit: RANGE_SIZE,
|
||||
...queryArgs,
|
||||
} satisfies Parameters<typeof useGetImageCollectionQuery>[0];
|
||||
|
||||
const options = {
|
||||
selectFromResult: ({ data }) => {
|
||||
const imageDTO = data?.items?.[positionInfo.itemIndex] || null;
|
||||
if (imageDTO && imageDTO.image_name !== imageName) {
|
||||
log.warn(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`);
|
||||
}
|
||||
return { imageDTO };
|
||||
},
|
||||
} satisfies Parameters<typeof useGetImageCollectionQuery>[1];
|
||||
|
||||
return { arg, options };
|
||||
}, [imageName, index, queryArgs, starredCount]);
|
||||
|
||||
const { imageDTO } = useGetImageCollectionQuery(arg, options);
|
||||
const { imageDTO } = useListImagesQuery(arg, options);
|
||||
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
// Individual image component that gets its data from batched requests
|
||||
// Individual image component that gets its data from RTK Query cache
|
||||
const ImageAtPosition = memo(
|
||||
({
|
||||
imageName,
|
||||
index,
|
||||
starredCount,
|
||||
queryArgs,
|
||||
}: {
|
||||
imageName: string;
|
||||
index: number;
|
||||
starredCount: number;
|
||||
queryArgs: ImageCollectionQueryArgs;
|
||||
}) => {
|
||||
const imageDTO = useImageFromBatch(imageName, index, starredCount, queryArgs);
|
||||
({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImagesArgs }) => {
|
||||
const imageDTO = useImageDTOFromListQuery(index, imageName, queryArgs);
|
||||
|
||||
if (!imageDTO) {
|
||||
return <Skeleton w="full" h="full" />;
|
||||
@@ -142,58 +87,9 @@ const ImageAtPosition = memo(
|
||||
);
|
||||
ImageAtPosition.displayName = 'ImageAtPosition';
|
||||
|
||||
export const useDebouncedImageCollectionQueryArgs = () => {
|
||||
const _queryArgs = useAppSelector(selectImageCollectionQueryArgs);
|
||||
const [queryArgs] = useDebounce(_queryArgs, DEBOUNCE_DELAY);
|
||||
return queryArgs;
|
||||
};
|
||||
|
||||
// Memoized item content function that uses image names as data but batches requests
|
||||
const itemContent: GridItemContent<string, GridContext> = (index, imageName, { queryArgs, starredCount }) => {
|
||||
if (!imageName) {
|
||||
return <Skeleton w="full" h="full" />;
|
||||
}
|
||||
return <ImageAtPosition imageName={imageName} index={index} starredCount={starredCount} queryArgs={queryArgs} />;
|
||||
};
|
||||
|
||||
// Memoized compute key function using image names
|
||||
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
|
||||
return `${JSON.stringify(queryArgs)}-${imageName || index}`;
|
||||
};
|
||||
|
||||
// Hook to prefetch ranges based on visible area
|
||||
const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQueryArgs) => {
|
||||
const [triggerGetImageCollection] = useLazyGetImageCollectionQuery();
|
||||
|
||||
const prefetchRange = useCallback(
|
||||
(startIndex: number, endIndex: number) => {
|
||||
const ranges = {
|
||||
starred: new Set<number>(),
|
||||
unstarred: new Set<number>(),
|
||||
};
|
||||
|
||||
// Collect all unique ranges needed for the visible area
|
||||
for (let i = startIndex; i <= endIndex; i++) {
|
||||
const positionInfo = getPositionInfo(i, starredCount);
|
||||
ranges[positionInfo.collection].add(positionInfo.offset);
|
||||
}
|
||||
|
||||
// Trigger queries for each unique range
|
||||
for (const [collection, offsets] of objectEntries(ranges)) {
|
||||
for (const offset of offsets) {
|
||||
triggerGetImageCollection({
|
||||
collection,
|
||||
offset,
|
||||
limit: RANGE_SIZE,
|
||||
...queryArgs,
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
[starredCount, queryArgs, triggerGetImageCollection]
|
||||
);
|
||||
|
||||
return prefetchRange;
|
||||
return `${JSON.stringify(queryArgs)}-${imageName}`;
|
||||
};
|
||||
|
||||
// Physical DOM-based grid calculation using refs (based on working old implementation)
|
||||
@@ -241,40 +137,72 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => {
|
||||
};
|
||||
|
||||
// Check if an item at a given index is visible in the viewport
|
||||
const isItemVisible = (index: number, rootEl: HTMLDivElement): null | 'start' | 'center' | 'end' => {
|
||||
const scrollIntoView = (
|
||||
index: number,
|
||||
rootEl: HTMLDivElement,
|
||||
virtuosoGridHandle: VirtuosoGridHandle,
|
||||
range: ListRange
|
||||
) => {
|
||||
if (range.endIndex === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// First get the virtuoso grid list root element
|
||||
const gridList = rootEl.querySelector('.virtuoso-grid-list') as HTMLElement;
|
||||
|
||||
if (!gridList) {
|
||||
return null;
|
||||
// No grid - cannot scroll!
|
||||
return;
|
||||
}
|
||||
|
||||
// Then find the specific item within the grid list
|
||||
const targetItem = gridList.querySelector(`.virtuoso-grid-item[data-index="${index}"]`) as HTMLElement;
|
||||
|
||||
if (!targetItem) {
|
||||
return null;
|
||||
if (index > range.endIndex) {
|
||||
virtuosoGridHandle.scrollToIndex({
|
||||
index,
|
||||
behavior: 'auto',
|
||||
align: 'start',
|
||||
});
|
||||
} else if (index < range.startIndex) {
|
||||
virtuosoGridHandle.scrollToIndex({
|
||||
index,
|
||||
behavior: 'auto',
|
||||
align: 'end',
|
||||
});
|
||||
} else {
|
||||
log.warn(`Unable to find item index ${index} but it is in range ${range.startIndex}-${range.endIndex}`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const itemRect = targetItem.getBoundingClientRect();
|
||||
const rootRect = rootEl.getBoundingClientRect();
|
||||
|
||||
if (itemRect.top < rootRect.top) {
|
||||
return 'start';
|
||||
virtuosoGridHandle.scrollToIndex({
|
||||
index,
|
||||
behavior: 'auto',
|
||||
align: 'start',
|
||||
});
|
||||
} else if (itemRect.bottom > rootRect.bottom) {
|
||||
virtuosoGridHandle.scrollToIndex({
|
||||
index,
|
||||
behavior: 'auto',
|
||||
align: 'end',
|
||||
});
|
||||
}
|
||||
|
||||
if (itemRect.bottom > rootRect.bottom) {
|
||||
return 'end';
|
||||
}
|
||||
|
||||
return 'center';
|
||||
return;
|
||||
};
|
||||
|
||||
// Hook for keyboard navigation using physical DOM measurements
|
||||
const useKeyboardNavigation = (
|
||||
imageNames: string[],
|
||||
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
|
||||
rootRef: React.RefObject<HTMLDivElement>
|
||||
rootRef: React.RefObject<HTMLDivElement>,
|
||||
rangeRef: MutableRefObject<ListRange>
|
||||
) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
@@ -291,7 +219,9 @@ const useKeyboardNavigation = (
|
||||
const handleKeyDown = useCallback(
|
||||
(event: KeyboardEvent) => {
|
||||
const rootEl = rootRef.current;
|
||||
if (!rootEl) {
|
||||
const virtuosoGridHandle = virtuosoRef.current;
|
||||
const range = rangeRef.current;
|
||||
if (!rootEl || !virtuosoGridHandle) {
|
||||
return;
|
||||
}
|
||||
if (imageNames.length === 0) {
|
||||
@@ -358,21 +288,11 @@ const useKeyboardNavigation = (
|
||||
const newImageName = imageNames[newIndex];
|
||||
if (newImageName) {
|
||||
dispatch(selectionChanged([newImageName]));
|
||||
|
||||
// Only scroll if the selected item is not visible
|
||||
const vis = isItemVisible(newIndex, rootEl);
|
||||
if (!vis || vis === 'center') {
|
||||
return;
|
||||
}
|
||||
virtuosoRef.current?.scrollToIndex({
|
||||
index: newIndex,
|
||||
behavior: 'smooth',
|
||||
align: vis,
|
||||
});
|
||||
scrollIntoView(newIndex, rootEl, virtuosoGridHandle, range);
|
||||
}
|
||||
}
|
||||
},
|
||||
[rootRef, imageNames, currentIndex, dispatch, virtuosoRef]
|
||||
[rootRef, virtuosoRef, rangeRef, imageNames, currentIndex, dispatch]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
@@ -387,16 +307,11 @@ const useKeyboardNavigation = (
|
||||
export const NewGallery = memo(() => {
|
||||
const queryArgs = useDebouncedImageCollectionQueryArgs();
|
||||
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
|
||||
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
|
||||
|
||||
// Get the ordered list of image names - this is our primary data source
|
||||
// Get the ordered list of image names - this is our primary data source for virtualization
|
||||
const { data: imageNames = [], isLoading } = useGetImageNamesQuery(queryArgs);
|
||||
|
||||
// Get starred count for position calculations
|
||||
const { data: counts } = useGetImageCollectionCountsQuery(queryArgs);
|
||||
const starredCount = counts?.starred_count ?? 0;
|
||||
|
||||
const prefetchRange = usePrefetchRanges(starredCount, queryArgs);
|
||||
|
||||
// Reset scroll position when query parameters change
|
||||
useEffect(() => {
|
||||
if (virtuosoRef.current && imageNames.length > 0) {
|
||||
@@ -407,7 +322,7 @@ export const NewGallery = memo(() => {
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Enable keyboard navigation
|
||||
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
|
||||
useKeyboardNavigation(imageNames, virtuosoRef, rootRef, rangeRef);
|
||||
|
||||
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars({
|
||||
@@ -439,24 +354,25 @@ export const NewGallery = memo(() => {
|
||||
};
|
||||
}, [scroller, initialize, osInstance]);
|
||||
|
||||
// Handle range changes to prefetch data for visible + buffer areas
|
||||
const handleRangeChanged = useCallback(
|
||||
(range: ListRange) => {
|
||||
prefetchRange(range.startIndex, range.endIndex);
|
||||
},
|
||||
[prefetchRange]
|
||||
);
|
||||
// Handle range changes - RTK Query will automatically cache and manage loading
|
||||
const handleRangeChanged = useCallback((range: ListRange) => {
|
||||
rangeRef.current = range;
|
||||
}, []);
|
||||
|
||||
const context = useMemo(
|
||||
() =>
|
||||
({
|
||||
imageNames,
|
||||
queryArgs,
|
||||
starredCount,
|
||||
}) satisfies GridContext,
|
||||
[imageNames, queryArgs, starredCount]
|
||||
[imageNames, queryArgs]
|
||||
);
|
||||
|
||||
// Item content function
|
||||
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName, ctx) => {
|
||||
return <ImageAtPosition index={index} imageName={imageName} queryArgs={ctx.queryArgs} />;
|
||||
}, []);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Flex height="100%" alignItems="center" justifyContent="center">
|
||||
|
||||
@@ -45,6 +45,7 @@ export const selectImageCollectionQueryArgs = createMemoizedSelector(selectGalle
|
||||
search_term: gallery.searchTerm || undefined,
|
||||
order_dir: gallery.orderDir as SQLiteDirection,
|
||||
is_intermediate: false,
|
||||
starred_first: true,
|
||||
}));
|
||||
export const selectAutoAssignBoardOnClick = createSelector(
|
||||
selectGallerySlice,
|
||||
|
||||
@@ -50,10 +50,10 @@ export const imagesApi = api.injectEndpoints({
|
||||
url: getListImagesUrl(queryArgs),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result, error, { board_id, categories }) => {
|
||||
providesTags: (result, error, queryArgs) => {
|
||||
return [
|
||||
// Make the tags the same as the cache key
|
||||
{ type: 'ImageList', id: getListImagesUrl({ board_id, categories }) },
|
||||
{ type: 'ImageList', id: JSON.stringify(queryArgs) },
|
||||
'FetchOnReconnect',
|
||||
];
|
||||
},
|
||||
@@ -493,6 +493,45 @@ export const imagesApi = api.injectEndpoints({
|
||||
}),
|
||||
providesTags: ['ImageNameList', 'FetchOnReconnect'],
|
||||
}),
|
||||
/**
|
||||
* Get paginated images with starred first (unified list)
|
||||
*/
|
||||
getUnifiedImageList: build.query<
|
||||
ListImagesResponse,
|
||||
{
|
||||
offset?: number;
|
||||
limit?: number;
|
||||
image_origin?: 'internal' | 'external' | null;
|
||||
categories?: ImageCategory[] | null;
|
||||
is_intermediate?: boolean | null;
|
||||
board_id?: string | null;
|
||||
search_term?: string | null;
|
||||
order_dir?: SQLiteDirection;
|
||||
}
|
||||
>({
|
||||
query: (queryArgs) => ({
|
||||
url: getListImagesUrl({ ...queryArgs, starred_first: true }),
|
||||
method: 'GET',
|
||||
}),
|
||||
providesTags: (result, error, { board_id, categories }) => [
|
||||
{ type: 'ImageList', id: getListImagesUrl({ board_id, categories }) },
|
||||
'FetchOnReconnect',
|
||||
],
|
||||
async onQueryStarted(_, { dispatch, queryFulfilled }) {
|
||||
// Populate the getImageDTO cache with these images
|
||||
const res = await queryFulfilled;
|
||||
const imageDTOs = res.data.items;
|
||||
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
|
||||
for (const imageDTO of imageDTOs) {
|
||||
updates.push({
|
||||
endpointName: 'getImageDTO',
|
||||
arg: imageDTO.image_name,
|
||||
value: imageDTO,
|
||||
});
|
||||
}
|
||||
dispatch(imagesApi.util.upsertQueryEntries(updates));
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -518,6 +557,7 @@ export const {
|
||||
useGetImageCollectionQuery,
|
||||
useLazyGetImageCollectionQuery,
|
||||
useGetImageNamesQuery,
|
||||
useGetUnifiedImageListQuery,
|
||||
} = imagesApi;
|
||||
|
||||
/**
|
||||
|
||||
Reference in New Issue
Block a user