refactor: gallery scroll (improved impl)

This commit is contained in:
psychedelicious
2025-06-24 23:00:24 +10:00
parent f55c593705
commit 434d8a2b12

View File

@@ -1,10 +1,12 @@
import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
import { logger } from 'app/logging/logger';
import { useAppSelector } from 'app/store/storeHooks';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import {
selectGalleryImageMinimumWidth,
selectImageCollectionQueryArgs,
selectLastSelectedImage,
} from 'features/gallery/store/gallerySelectors';
import { selectionChanged } from 'features/gallery/store/gallerySlice';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type {
@@ -43,6 +45,9 @@ type ImageCollectionQueryArgs = {
const RANGE_SIZE = 50;
const VIEWPORT_BUFFER = 2048;
const SCROLL_SEEK_VELOCITY_THRESHOLD = 2048;
const DEBOUNCE_DELAY = 500;
const GRID_GAP = 2;
const SPINNER_OPACITY = 0.3;
type GridContext = {
queryArgs: ImageCollectionQueryArgs;
@@ -99,7 +104,7 @@ const useImageFromBatch = (
selectFromResult: ({ data }) => {
const imageDTO = data?.items?.[positionInfo.itemIndex] || null;
if (imageDTO && imageDTO.image_name !== imageName) {
log.warnOnce(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`);
log.warn(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`);
}
return { imageDTO };
},
@@ -139,7 +144,7 @@ ImageAtPosition.displayName = 'ImageAtPosition';
export const useDebouncedImageCollectionQueryArgs = () => {
const _queryArgs = useAppSelector(selectImageCollectionQueryArgs);
const [queryArgs] = useDebounce(_queryArgs, 500);
const [queryArgs] = useDebounce(_queryArgs, DEBOUNCE_DELAY);
return queryArgs;
};
@@ -191,6 +196,193 @@ const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQuery
return prefetchRange;
};
// Physical DOM-based grid calculation using refs (based on working old implementation)
const getImagesPerRow = (rootEl: HTMLDivElement): number => {
// Start from root and find virtuoso grid elements
const gridElement = rootEl.querySelector('.virtuoso-grid-list');
if (!gridElement) {
return 0;
}
const firstGridItem = gridElement.querySelector('.virtuoso-grid-item');
if (!firstGridItem) {
return 0;
}
const itemRect = firstGridItem.getBoundingClientRect();
const containerRect = gridElement.getBoundingClientRect();
// Get the computed gap from CSS
const gridStyle = window.getComputedStyle(gridElement);
const gapValue = gridStyle.gap;
const gap = parseFloat(gapValue);
if (isNaN(gap) || !itemRect.width || !itemRect.height || !containerRect.width || !containerRect.height) {
return 0;
}
// Use the exact calculation from the working old implementation
let imagesPerRow = 0;
let spaceUsed = 0;
// Floating point precision can cause imagesPerRow to be 1 too small. Adding 1px to the container size fixes
// this, without the possibility of accidentally adding an extra column.
while (spaceUsed + itemRect.width <= containerRect.width + 1) {
imagesPerRow++; // Increment the number of images
spaceUsed += itemRect.width; // Add image size to the used space
if (spaceUsed + gap <= containerRect.width) {
spaceUsed += gap; // Add gap size to the used space after each image except after the last image
}
}
return Math.max(1, imagesPerRow);
};
// Check if an item at a given index is visible in the viewport
const isItemVisible = (index: number, rootEl: HTMLDivElement): null | 'start' | 'center' | 'end' => {
// First get the virtuoso grid list root element
const gridList = rootEl.querySelector('.virtuoso-grid-list') as HTMLElement;
if (!gridList) {
return null;
}
// Then find the specific item within the grid list
const targetItem = gridList.querySelector(`.virtuoso-grid-item[data-index="${index}"]`) as HTMLElement;
if (!targetItem) {
return null;
}
const itemRect = targetItem.getBoundingClientRect();
const rootRect = rootEl.getBoundingClientRect();
if (itemRect.top < rootRect.top) {
return 'start';
}
if (itemRect.bottom > rootRect.bottom) {
return 'end';
}
return 'center';
};
// Hook for keyboard navigation using physical DOM measurements
const useKeyboardNavigation = (
imageNames: string[],
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
rootRef: React.RefObject<HTMLDivElement>
) => {
const dispatch = useAppDispatch();
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
// Get current index of selected image
const currentIndex = useMemo(() => {
if (!lastSelectedImage || imageNames.length === 0) {
return 0;
}
const index = imageNames.findIndex((name) => name === lastSelectedImage);
return index >= 0 ? index : 0;
}, [lastSelectedImage, imageNames]);
const handleKeyDown = useCallback(
(event: KeyboardEvent) => {
const rootEl = rootRef.current;
if (!rootEl) {
return;
}
if (imageNames.length === 0) {
return;
}
// Only handle arrow keys
if (!['ArrowUp', 'ArrowDown', 'ArrowLeft', 'ArrowRight'].includes(event.key)) {
return;
}
// Don't interfere if user is typing in an input
if (event.target instanceof HTMLInputElement || event.target instanceof HTMLTextAreaElement) {
return;
}
const imagesPerRow = getImagesPerRow(rootEl);
if (imagesPerRow === 0) {
// This can happen if the grid is not yet rendered or has no items
return;
}
event.preventDefault();
let newIndex = currentIndex;
switch (event.key) {
case 'ArrowLeft':
if (currentIndex > 0) {
newIndex = currentIndex - 1;
} else {
// Wrap to last image
newIndex = imageNames.length - 1;
}
break;
case 'ArrowRight':
if (currentIndex < imageNames.length - 1) {
newIndex = currentIndex + 1;
} else {
// Wrap to first image
newIndex = 0;
}
break;
case 'ArrowUp':
// If on first row, stay on current image
if (currentIndex < imagesPerRow) {
newIndex = currentIndex;
} else {
newIndex = Math.max(0, currentIndex - imagesPerRow);
}
break;
case 'ArrowDown':
// If no images below, stay on current image
if (currentIndex >= imageNames.length - imagesPerRow) {
newIndex = currentIndex;
} else {
newIndex = Math.min(imageNames.length - 1, currentIndex + imagesPerRow);
}
break;
}
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < imageNames.length) {
const newImageName = imageNames[newIndex];
if (newImageName) {
dispatch(selectionChanged([newImageName]));
// Only scroll if the selected item is not visible
const vis = isItemVisible(newIndex, rootEl);
if (!vis || vis === 'center') {
return;
}
virtuosoRef.current?.scrollToIndex({
index: newIndex,
behavior: 'smooth',
align: vis,
});
}
}
},
[rootRef, imageNames, currentIndex, dispatch, virtuosoRef]
);
useEffect(() => {
document.addEventListener('keydown', handleKeyDown);
return () => {
document.removeEventListener('keydown', handleKeyDown);
};
}, [handleKeyDown]);
};
// Main gallery component
export const NewGallery = memo(() => {
const queryArgs = useDebouncedImageCollectionQueryArgs();
@@ -213,6 +405,10 @@ export const NewGallery = memo(() => {
}, [queryArgs, imageNames.length]);
const rootRef = useRef<HTMLDivElement>(null);
// Enable keyboard navigation
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
const [scroller, setScroller] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
@@ -264,7 +460,7 @@ export const NewGallery = memo(() => {
if (isLoading) {
return (
<Flex height="100%" alignItems="center" justifyContent="center">
<Spinner size="lg" opacity={0.3} />
<Spinner size="lg" opacity={SPINNER_OPACITY} />
<Text ml={4}>Loading gallery...</Text>
</Flex>
);
@@ -285,7 +481,7 @@ export const NewGallery = memo(() => {
context={context}
totalCount={imageNames.length}
data={imageNames}
increaseViewportBy={VIEWPORT_BUFFER}
overscan={VIEWPORT_BUFFER}
itemContent={itemContent}
computeItemKey={computeItemKey}
components={components}
@@ -316,7 +512,7 @@ const ListComponent: GridComponents<GridContext>['List'] = forwardRef((props, re
<Grid
ref={ref}
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
gap={2}
gap={GRID_GAP}
{...props}
/>
);