From 434d8a2b125f94ee7aa9d247662d7304fb543fcd Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 24 Jun 2025 23:00:24 +1000 Subject: [PATCH] refactor: gallery scroll (improved impl) --- .../gallery/components/NewGallery.tsx | 208 +++++++++++++++++- 1 file changed, 202 insertions(+), 6 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx index b6fb9ed6ae..f07458b4ca 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -1,10 +1,12 @@ import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library'; import { logger } from 'app/logging/logger'; -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { selectGalleryImageMinimumWidth, selectImageCollectionQueryArgs, + selectLastSelectedImage, } from 'features/gallery/store/gallerySelectors'; +import { selectionChanged } from 'features/gallery/store/gallerySlice'; import { useOverlayScrollbars } from 'overlayscrollbars-react'; import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import type { @@ -43,6 +45,9 @@ type ImageCollectionQueryArgs = { const RANGE_SIZE = 50; const VIEWPORT_BUFFER = 2048; const SCROLL_SEEK_VELOCITY_THRESHOLD = 2048; +const DEBOUNCE_DELAY = 500; +const GRID_GAP = 2; +const SPINNER_OPACITY = 0.3; type GridContext = { queryArgs: ImageCollectionQueryArgs; @@ -99,7 +104,7 @@ const useImageFromBatch = ( selectFromResult: ({ data }) => { const imageDTO = data?.items?.[positionInfo.itemIndex] || null; if (imageDTO && imageDTO.image_name !== imageName) { - log.warnOnce(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`); + log.warn(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`); } return { imageDTO }; }, @@ -139,7 +144,7 @@ ImageAtPosition.displayName = 'ImageAtPosition'; export const useDebouncedImageCollectionQueryArgs = () => { const _queryArgs = useAppSelector(selectImageCollectionQueryArgs); - const [queryArgs] = useDebounce(_queryArgs, 500); + const [queryArgs] = useDebounce(_queryArgs, DEBOUNCE_DELAY); return queryArgs; }; @@ -191,6 +196,193 @@ const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQuery return prefetchRange; }; +// Physical DOM-based grid calculation using refs (based on working old implementation) +const getImagesPerRow = (rootEl: HTMLDivElement): number => { + // Start from root and find virtuoso grid elements + const gridElement = rootEl.querySelector('.virtuoso-grid-list'); + + if (!gridElement) { + return 0; + } + + const firstGridItem = gridElement.querySelector('.virtuoso-grid-item'); + + if (!firstGridItem) { + return 0; + } + + const itemRect = firstGridItem.getBoundingClientRect(); + const containerRect = gridElement.getBoundingClientRect(); + + // Get the computed gap from CSS + const gridStyle = window.getComputedStyle(gridElement); + const gapValue = gridStyle.gap; + const gap = parseFloat(gapValue); + + if (isNaN(gap) || !itemRect.width || !itemRect.height || !containerRect.width || !containerRect.height) { + return 0; + } + + // Use the exact calculation from the working old implementation + let imagesPerRow = 0; + let spaceUsed = 0; + + // Floating point precision can cause imagesPerRow to be 1 too small. Adding 1px to the container size fixes + // this, without the possibility of accidentally adding an extra column. + while (spaceUsed + itemRect.width <= containerRect.width + 1) { + imagesPerRow++; // Increment the number of images + spaceUsed += itemRect.width; // Add image size to the used space + if (spaceUsed + gap <= containerRect.width) { + spaceUsed += gap; // Add gap size to the used space after each image except after the last image + } + } + + return Math.max(1, imagesPerRow); +}; + +// Check if an item at a given index is visible in the viewport +const isItemVisible = (index: number, rootEl: HTMLDivElement): null | 'start' | 'center' | 'end' => { + // First get the virtuoso grid list root element + const gridList = rootEl.querySelector('.virtuoso-grid-list') as HTMLElement; + + if (!gridList) { + return null; + } + + // Then find the specific item within the grid list + const targetItem = gridList.querySelector(`.virtuoso-grid-item[data-index="${index}"]`) as HTMLElement; + + if (!targetItem) { + return null; + } + + const itemRect = targetItem.getBoundingClientRect(); + const rootRect = rootEl.getBoundingClientRect(); + + if (itemRect.top < rootRect.top) { + return 'start'; + } + + if (itemRect.bottom > rootRect.bottom) { + return 'end'; + } + + return 'center'; +}; + +// Hook for keyboard navigation using physical DOM measurements +const useKeyboardNavigation = ( + imageNames: string[], + virtuosoRef: React.RefObject, + rootRef: React.RefObject +) => { + const dispatch = useAppDispatch(); + const lastSelectedImage = useAppSelector(selectLastSelectedImage); + + // Get current index of selected image + const currentIndex = useMemo(() => { + if (!lastSelectedImage || imageNames.length === 0) { + return 0; + } + const index = imageNames.findIndex((name) => name === lastSelectedImage); + return index >= 0 ? index : 0; + }, [lastSelectedImage, imageNames]); + + const handleKeyDown = useCallback( + (event: KeyboardEvent) => { + const rootEl = rootRef.current; + if (!rootEl) { + return; + } + if (imageNames.length === 0) { + return; + } + + // Only handle arrow keys + if (!['ArrowUp', 'ArrowDown', 'ArrowLeft', 'ArrowRight'].includes(event.key)) { + return; + } + + // Don't interfere if user is typing in an input + if (event.target instanceof HTMLInputElement || event.target instanceof HTMLTextAreaElement) { + return; + } + + const imagesPerRow = getImagesPerRow(rootEl); + + if (imagesPerRow === 0) { + // This can happen if the grid is not yet rendered or has no items + return; + } + + event.preventDefault(); + + let newIndex = currentIndex; + + switch (event.key) { + case 'ArrowLeft': + if (currentIndex > 0) { + newIndex = currentIndex - 1; + } else { + // Wrap to last image + newIndex = imageNames.length - 1; + } + break; + case 'ArrowRight': + if (currentIndex < imageNames.length - 1) { + newIndex = currentIndex + 1; + } else { + // Wrap to first image + newIndex = 0; + } + break; + case 'ArrowUp': + // If on first row, stay on current image + if (currentIndex < imagesPerRow) { + newIndex = currentIndex; + } else { + newIndex = Math.max(0, currentIndex - imagesPerRow); + } + break; + case 'ArrowDown': + // If no images below, stay on current image + if (currentIndex >= imageNames.length - imagesPerRow) { + newIndex = currentIndex; + } else { + newIndex = Math.min(imageNames.length - 1, currentIndex + imagesPerRow); + } + break; + } + + if (newIndex !== currentIndex && newIndex >= 0 && newIndex < imageNames.length) { + const newImageName = imageNames[newIndex]; + if (newImageName) { + dispatch(selectionChanged([newImageName])); + + // Only scroll if the selected item is not visible + const vis = isItemVisible(newIndex, rootEl); + if (!vis || vis === 'center') { + return; + } + virtuosoRef.current?.scrollToIndex({ + index: newIndex, + behavior: 'smooth', + align: vis, + }); + } + } + }, + [rootRef, imageNames, currentIndex, dispatch, virtuosoRef] + ); + + useEffect(() => { + document.addEventListener('keydown', handleKeyDown); + return () => { + document.removeEventListener('keydown', handleKeyDown); + }; + }, [handleKeyDown]); +}; + // Main gallery component export const NewGallery = memo(() => { const queryArgs = useDebouncedImageCollectionQueryArgs(); @@ -213,6 +405,10 @@ export const NewGallery = memo(() => { }, [queryArgs, imageNames.length]); const rootRef = useRef(null); + + // Enable keyboard navigation + useKeyboardNavigation(imageNames, virtuosoRef, rootRef); + const [scroller, setScroller] = useState(null); const [initialize, osInstance] = useOverlayScrollbars({ defer: true, @@ -264,7 +460,7 @@ export const NewGallery = memo(() => { if (isLoading) { return ( - + Loading gallery... ); @@ -285,7 +481,7 @@ export const NewGallery = memo(() => { context={context} totalCount={imageNames.length} data={imageNames} - increaseViewportBy={VIEWPORT_BUFFER} + overscan={VIEWPORT_BUFFER} itemContent={itemContent} computeItemKey={computeItemKey} components={components} @@ -316,7 +512,7 @@ const ListComponent: GridComponents['List'] = forwardRef((props, re );