mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 19:24:57 -05:00
refactor: gallery scroll (improved impl)
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import {
|
||||
selectGalleryImageMinimumWidth,
|
||||
selectImageCollectionQueryArgs,
|
||||
selectLastSelectedImage,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import type {
|
||||
@@ -43,6 +45,9 @@ type ImageCollectionQueryArgs = {
|
||||
const RANGE_SIZE = 50;
|
||||
const VIEWPORT_BUFFER = 2048;
|
||||
const SCROLL_SEEK_VELOCITY_THRESHOLD = 2048;
|
||||
const DEBOUNCE_DELAY = 500;
|
||||
const GRID_GAP = 2;
|
||||
const SPINNER_OPACITY = 0.3;
|
||||
|
||||
type GridContext = {
|
||||
queryArgs: ImageCollectionQueryArgs;
|
||||
@@ -99,7 +104,7 @@ const useImageFromBatch = (
|
||||
selectFromResult: ({ data }) => {
|
||||
const imageDTO = data?.items?.[positionInfo.itemIndex] || null;
|
||||
if (imageDTO && imageDTO.image_name !== imageName) {
|
||||
log.warnOnce(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`);
|
||||
log.warn(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`);
|
||||
}
|
||||
return { imageDTO };
|
||||
},
|
||||
@@ -139,7 +144,7 @@ ImageAtPosition.displayName = 'ImageAtPosition';
|
||||
|
||||
export const useDebouncedImageCollectionQueryArgs = () => {
|
||||
const _queryArgs = useAppSelector(selectImageCollectionQueryArgs);
|
||||
const [queryArgs] = useDebounce(_queryArgs, 500);
|
||||
const [queryArgs] = useDebounce(_queryArgs, DEBOUNCE_DELAY);
|
||||
return queryArgs;
|
||||
};
|
||||
|
||||
@@ -191,6 +196,193 @@ const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQuery
|
||||
return prefetchRange;
|
||||
};
|
||||
|
||||
// Physical DOM-based grid calculation using refs (based on working old implementation)
|
||||
const getImagesPerRow = (rootEl: HTMLDivElement): number => {
|
||||
// Start from root and find virtuoso grid elements
|
||||
const gridElement = rootEl.querySelector('.virtuoso-grid-list');
|
||||
|
||||
if (!gridElement) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const firstGridItem = gridElement.querySelector('.virtuoso-grid-item');
|
||||
|
||||
if (!firstGridItem) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const itemRect = firstGridItem.getBoundingClientRect();
|
||||
const containerRect = gridElement.getBoundingClientRect();
|
||||
|
||||
// Get the computed gap from CSS
|
||||
const gridStyle = window.getComputedStyle(gridElement);
|
||||
const gapValue = gridStyle.gap;
|
||||
const gap = parseFloat(gapValue);
|
||||
|
||||
if (isNaN(gap) || !itemRect.width || !itemRect.height || !containerRect.width || !containerRect.height) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Use the exact calculation from the working old implementation
|
||||
let imagesPerRow = 0;
|
||||
let spaceUsed = 0;
|
||||
|
||||
// Floating point precision can cause imagesPerRow to be 1 too small. Adding 1px to the container size fixes
|
||||
// this, without the possibility of accidentally adding an extra column.
|
||||
while (spaceUsed + itemRect.width <= containerRect.width + 1) {
|
||||
imagesPerRow++; // Increment the number of images
|
||||
spaceUsed += itemRect.width; // Add image size to the used space
|
||||
if (spaceUsed + gap <= containerRect.width) {
|
||||
spaceUsed += gap; // Add gap size to the used space after each image except after the last image
|
||||
}
|
||||
}
|
||||
|
||||
return Math.max(1, imagesPerRow);
|
||||
};
|
||||
|
||||
// Check if an item at a given index is visible in the viewport
|
||||
const isItemVisible = (index: number, rootEl: HTMLDivElement): null | 'start' | 'center' | 'end' => {
|
||||
// First get the virtuoso grid list root element
|
||||
const gridList = rootEl.querySelector('.virtuoso-grid-list') as HTMLElement;
|
||||
|
||||
if (!gridList) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Then find the specific item within the grid list
|
||||
const targetItem = gridList.querySelector(`.virtuoso-grid-item[data-index="${index}"]`) as HTMLElement;
|
||||
|
||||
if (!targetItem) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const itemRect = targetItem.getBoundingClientRect();
|
||||
const rootRect = rootEl.getBoundingClientRect();
|
||||
|
||||
if (itemRect.top < rootRect.top) {
|
||||
return 'start';
|
||||
}
|
||||
|
||||
if (itemRect.bottom > rootRect.bottom) {
|
||||
return 'end';
|
||||
}
|
||||
|
||||
return 'center';
|
||||
};
|
||||
|
||||
// Hook for keyboard navigation using physical DOM measurements
|
||||
const useKeyboardNavigation = (
|
||||
imageNames: string[],
|
||||
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
|
||||
rootRef: React.RefObject<HTMLDivElement>
|
||||
) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
|
||||
// Get current index of selected image
|
||||
const currentIndex = useMemo(() => {
|
||||
if (!lastSelectedImage || imageNames.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
const index = imageNames.findIndex((name) => name === lastSelectedImage);
|
||||
return index >= 0 ? index : 0;
|
||||
}, [lastSelectedImage, imageNames]);
|
||||
|
||||
const handleKeyDown = useCallback(
|
||||
(event: KeyboardEvent) => {
|
||||
const rootEl = rootRef.current;
|
||||
if (!rootEl) {
|
||||
return;
|
||||
}
|
||||
if (imageNames.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only handle arrow keys
|
||||
if (!['ArrowUp', 'ArrowDown', 'ArrowLeft', 'ArrowRight'].includes(event.key)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Don't interfere if user is typing in an input
|
||||
if (event.target instanceof HTMLInputElement || event.target instanceof HTMLTextAreaElement) {
|
||||
return;
|
||||
}
|
||||
|
||||
const imagesPerRow = getImagesPerRow(rootEl);
|
||||
|
||||
if (imagesPerRow === 0) {
|
||||
// This can happen if the grid is not yet rendered or has no items
|
||||
return;
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
|
||||
let newIndex = currentIndex;
|
||||
|
||||
switch (event.key) {
|
||||
case 'ArrowLeft':
|
||||
if (currentIndex > 0) {
|
||||
newIndex = currentIndex - 1;
|
||||
} else {
|
||||
// Wrap to last image
|
||||
newIndex = imageNames.length - 1;
|
||||
}
|
||||
break;
|
||||
case 'ArrowRight':
|
||||
if (currentIndex < imageNames.length - 1) {
|
||||
newIndex = currentIndex + 1;
|
||||
} else {
|
||||
// Wrap to first image
|
||||
newIndex = 0;
|
||||
}
|
||||
break;
|
||||
case 'ArrowUp':
|
||||
// If on first row, stay on current image
|
||||
if (currentIndex < imagesPerRow) {
|
||||
newIndex = currentIndex;
|
||||
} else {
|
||||
newIndex = Math.max(0, currentIndex - imagesPerRow);
|
||||
}
|
||||
break;
|
||||
case 'ArrowDown':
|
||||
// If no images below, stay on current image
|
||||
if (currentIndex >= imageNames.length - imagesPerRow) {
|
||||
newIndex = currentIndex;
|
||||
} else {
|
||||
newIndex = Math.min(imageNames.length - 1, currentIndex + imagesPerRow);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < imageNames.length) {
|
||||
const newImageName = imageNames[newIndex];
|
||||
if (newImageName) {
|
||||
dispatch(selectionChanged([newImageName]));
|
||||
|
||||
// Only scroll if the selected item is not visible
|
||||
const vis = isItemVisible(newIndex, rootEl);
|
||||
if (!vis || vis === 'center') {
|
||||
return;
|
||||
}
|
||||
virtuosoRef.current?.scrollToIndex({
|
||||
index: newIndex,
|
||||
behavior: 'smooth',
|
||||
align: vis,
|
||||
});
|
||||
}
|
||||
}
|
||||
},
|
||||
[rootRef, imageNames, currentIndex, dispatch, virtuosoRef]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
document.addEventListener('keydown', handleKeyDown);
|
||||
return () => {
|
||||
document.removeEventListener('keydown', handleKeyDown);
|
||||
};
|
||||
}, [handleKeyDown]);
|
||||
};
|
||||
|
||||
// Main gallery component
|
||||
export const NewGallery = memo(() => {
|
||||
const queryArgs = useDebouncedImageCollectionQueryArgs();
|
||||
@@ -213,6 +405,10 @@ export const NewGallery = memo(() => {
|
||||
}, [queryArgs, imageNames.length]);
|
||||
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Enable keyboard navigation
|
||||
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
|
||||
|
||||
const [scroller, setScroller] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars({
|
||||
defer: true,
|
||||
@@ -264,7 +460,7 @@ export const NewGallery = memo(() => {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Flex height="100%" alignItems="center" justifyContent="center">
|
||||
<Spinner size="lg" opacity={0.3} />
|
||||
<Spinner size="lg" opacity={SPINNER_OPACITY} />
|
||||
<Text ml={4}>Loading gallery...</Text>
|
||||
</Flex>
|
||||
);
|
||||
@@ -285,7 +481,7 @@ export const NewGallery = memo(() => {
|
||||
context={context}
|
||||
totalCount={imageNames.length}
|
||||
data={imageNames}
|
||||
increaseViewportBy={VIEWPORT_BUFFER}
|
||||
overscan={VIEWPORT_BUFFER}
|
||||
itemContent={itemContent}
|
||||
computeItemKey={computeItemKey}
|
||||
components={components}
|
||||
@@ -316,7 +512,7 @@ const ListComponent: GridComponents<GridContext>['List'] = forwardRef((props, re
|
||||
<Grid
|
||||
ref={ref}
|
||||
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
|
||||
gap={2}
|
||||
gap={GRID_GAP}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user