refactor: gallery scroll (improved impl)

This commit is contained in:
psychedelicious
2025-06-24 19:58:59 +10:00
parent 054428730d
commit a339cec36f
9 changed files with 289 additions and 224 deletions

View File

@@ -1,3 +1,4 @@
from time import time
from typing import Literal, Optional
from PIL.Image import Image as PILImageType

View File

@@ -1,6 +1,7 @@
import '@fontsource-variable/inter';
import 'overlayscrollbars/overlayscrollbars.css';
import '@xyflow/react/dist/base.css';
import 'common/components/OverlayScrollbars/overlayscrollbars.css';
import { ChakraProvider, DarkMode, extendTheme, theme as _theme, TOAST_OPTIONS } from '@invoke-ai/ui-library';
import type { ReactNode } from 'react';

View File

@@ -1,9 +1,72 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import type { RootState } from 'app/store/store';
import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { uniq } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types';
// Type for image collection query arguments
type ImageCollectionQueryArgs = {
board_id?: string;
categories?: ImageCategory[];
search_term?: string;
order_dir?: SQLiteDirection;
is_intermediate: boolean;
};
/**
* Helper function to get all cached image data from collection queries
* Returns a combined array of starred images followed by unstarred images
*/
const getCachedImageList = (state: RootState, queryArgs: ImageCollectionQueryArgs): ImageDTO[] => {
const countsQueryResult = imagesApi.endpoints.getImageCollectionCounts.select(queryArgs)(state);
if (!countsQueryResult.data) {
return [];
}
const starredCount = countsQueryResult.data.starred_count ?? 0;
const totalCount = countsQueryResult.data.total_count ?? 0;
const unstarredCount = totalCount - starredCount;
const imageDTOs: ImageDTO[] = [];
// Add starred images first (in order)
if (starredCount > 0) {
for (let offset = 0; offset < starredCount; offset += 50) {
const queryResult = imagesApi.endpoints.getImageCollection.select({
collection: 'starred',
offset,
limit: 50,
...queryArgs,
})(state);
if (queryResult.data?.items) {
imageDTOs.push(...queryResult.data.items);
}
}
}
// Add unstarred images (in order)
if (unstarredCount > 0) {
for (let offset = 0; offset < unstarredCount; offset += 50) {
const queryResult = imagesApi.endpoints.getImageCollection.select({
collection: 'unstarred',
offset,
limit: 50,
...queryArgs,
})(state);
if (queryResult.data?.items) {
imageDTOs.push(...queryResult.data.items);
}
}
}
return imageDTOs;
};
export const galleryImageClicked = createAction<{
imageName: string;
@@ -30,15 +93,21 @@ export const addGalleryImageClickedListener = (startAppListening: AppStartListen
effect: (action, { dispatch, getState }) => {
const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImagesQueryArgs(state);
const queryResult = imagesApi.endpoints.listImages.select(queryArgs)(state);
const queryArgs = selectImageCollectionQueryArgs(state);
if (!queryResult.data) {
// Should never happen if we have clicked a gallery image
// Get all cached image data
const imageDTOs = getCachedImageList(state, queryArgs);
// If we don't have the image data cached, we can't perform selection operations
// This can happen if the user clicks on an image before all data is loaded
if (imageDTOs.length === 0) {
// For basic click without modifiers, we can still set selection
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
dispatch(selectionChanged([imageName]));
}
return;
}
const imageDTOs = queryResult.data.items;
const selection = state.gallery.selection;
if (altKey) {

View File

@@ -170,6 +170,8 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
// serializableCheck: false,
// immutableCheck: false,
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
})

View File

@@ -1,6 +1,6 @@
.os-scrollbar {
/* The size of the scrollbar */
--os-size: 9px;
--os-size: 7px;
/* The axis-perpedicular padding of the scrollbar (horizontal: padding-y, vertical: padding-x) */
/* --os-padding-perpendicular: 0; */
/* The axis padding of the scrollbar (horizontal: padding-x, vertical: padding-y) */
@@ -22,11 +22,11 @@
/* The border radius of the scrollbar handle */
/* --os-handle-border-radius: 2px; */
/* The background of the scrollbar handle */
/* --os-handle-bg: var(--invokeai-colors-accentAlpha-500); */
--os-handle-bg: var(--invoke-colors-base-600);
/* The :hover background of the scrollbar handle */
/* --os-handle-bg-hover: var(--invokeai-colors-accentAlpha-700); */
--os-handle-bg-hover: var(--invoke-colors-base-500);
/* The :active background of the scrollbar handle */
/* --os-handle-bg-active: var(--invokeai-colors-accentAlpha-800); */
--os-handle-bg-active: var(--invoke-colors-base-400);
/* The border of the scrollbar handle */
/* --os-handle-border: none; */
/* The :hover border of the scrollbar handle */
@@ -34,7 +34,7 @@
/* The :active border of the scrollbar handle */
/* --os-handle-border-active: none; */
/* The min size of the scrollbar handle */
--os-handle-min-size: 50px;
/* --os-handle-min-size: 50px; */
/* The max size of the scrollbar handle */
/* --os-handle-max-size: none; */
/* The axis-perpedicular size of the scrollbar handle (horizontal: height, vertical: width) */
@@ -48,9 +48,9 @@
}
.os-scrollbar-handle {
cursor: grab;
/* cursor: grab; */
}
.os-scrollbar-handle:active {
cursor: grabbing;
/* cursor: grabbing; */
}

View File

@@ -70,7 +70,7 @@ export const GalleryPanel = memo(() => {
return (
<Flex flexDirection="column" alignItems="center" justifyContent="space-between" h="full" w="full" p={2} minH={0}>
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full">
<Tabs index={galleryView === 'images' ? 0 : 1} variant="enclosed" display="flex" flexDir="column" w="full" pb={2}>
<TabList gap={2} fontSize="sm" borderColor="base.800" alignItems="center" w="full">
<Text fontSize="sm" fontWeight="semibold" noOfLines={1} px="2" wordBreak="break-all">
{boardName}
@@ -89,6 +89,7 @@ export const GalleryPanel = memo(() => {
<Flex h="full" justifyContent="flex-end">
<GalleryUploadButton />
<GallerySettingsPopover />
<IconButton
size="sm"
variant="link"
@@ -100,17 +101,17 @@ export const GalleryPanel = memo(() => {
/>
</Flex>
</TabList>
<Collapse in={searchDisclosure.isOpen} style={COLLAPSE_STYLES}>
<Box w="full" pt={2}>
<GallerySearch
searchTerm={searchTerm}
onChangeSearchTerm={onChangeSearchTerm}
onResetSearchTerm={onResetSearchTerm}
/>
</Box>
</Collapse>
</Tabs>
<Collapse in={searchDisclosure.isOpen} style={COLLAPSE_STYLES}>
<Box w="full" pt={2}>
<GallerySearch
searchTerm={searchTerm}
onChangeSearchTerm={onChangeSearchTerm}
onResetSearchTerm={onResetSearchTerm}
/>
</Box>
</Collapse>
{/* <GalleryImageGrid />
<GalleryPagination /> */}
<NewGallery />

View File

@@ -1,11 +1,10 @@
import { IconButton, Input, InputGroup, InputRightElement, Spinner } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectListImagesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { useDebouncedImageCollectionQueryArgs } from 'features/gallery/components/NewGallery';
import type { ChangeEvent, KeyboardEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { useListImagesQuery } from 'services/api/endpoints/images';
import { useGetImageCollectionCountsQuery } from 'services/api/endpoints/images';
type Props = {
searchTerm: string;
@@ -15,10 +14,8 @@ type Props = {
export const GallerySearch = memo(({ searchTerm, onChangeSearchTerm, onResetSearchTerm }: Props) => {
const { t } = useTranslation();
const queryArgs = useAppSelector(selectListImagesQueryArgs);
const { isPending } = useListImagesQuery(queryArgs, {
selectFromResult: ({ isLoading, isFetching }) => ({ isPending: isLoading || isFetching }),
});
const queryArgs = useDebouncedImageCollectionQueryArgs();
const { isFetching } = useGetImageCollectionCountsQuery(queryArgs);
const handleChangeInput = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
@@ -46,12 +43,12 @@ export const GallerySearch = memo(({ searchTerm, onChangeSearchTerm, onResetSear
data-testid="image-search-input"
onKeyDown={handleKeydown}
/>
{isPending && (
{isFetching && (
<InputRightElement h="full" pe={2}>
<Spinner size="sm" opacity={0.5} />
</InputRightElement>
)}
{!isPending && searchTerm.length && (
{!isFetching && searchTerm.length && (
<InputRightElement h="full" pe={2}>
<IconButton
onClick={onResetSearchTerm}

View File

@@ -1,8 +1,7 @@
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { selectSearchTerm } from 'features/gallery/store/gallerySelectors';
import { searchTermChanged } from 'features/gallery/store/gallerySlice';
import { debounce } from 'lodash-es';
import { useCallback, useMemo, useState } from 'react';
import { useCallback } from 'react';
export const useGallerySearchTerm = () => {
// Highlander!
@@ -11,27 +10,16 @@ export const useGallerySearchTerm = () => {
const dispatch = useAppDispatch();
const searchTerm = useAppSelector(selectSearchTerm);
const [localSearchTerm, setLocalSearchTerm] = useState(searchTerm);
const debouncedSetSearchTerm = useMemo(() => {
return debounce((val: string) => {
dispatch(searchTermChanged(val));
}, 1000);
}, [dispatch]);
const onChange = useCallback(
(val: string) => {
setLocalSearchTerm(val);
debouncedSetSearchTerm(val);
dispatch(searchTermChanged(val));
},
[debouncedSetSearchTerm]
[dispatch]
);
const onReset = useCallback(() => {
debouncedSetSearchTerm.cancel();
setLocalSearchTerm('');
dispatch(searchTermChanged(''));
}, [debouncedSetSearchTerm, dispatch]);
}, [dispatch]);
return [localSearchTerm, onChange, onReset] as const;
return [searchTerm, onChange, onReset] as const;
};

View File

@@ -1,207 +1,190 @@
import { Box, Flex, forwardRef, Grid, GridItem, Image, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectGalleryImageMinimumWidth,
selectImageCollectionQueryArgs,
} from 'features/gallery/store/gallerySelectors';
import { memo, useCallback } from 'react';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type { GridComponents, ListRange, ScrollSeekConfiguration, VirtuosoGridHandle } from 'react-virtuoso';
import { VirtuosoGrid } from 'react-virtuoso';
import {
useGetImageCollectionCountsQuery,
useGetImageCollectionQuery,
useLazyGetImageCollectionQuery,
} from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images';
import type { ImageCategory, SQLiteDirection } from 'services/api/types';
import { useDebounce } from 'use-debounce';
// Types for range management
import { GalleryImage } from './ImageGrid/GalleryImage';
// Type for image collection query arguments
type ImageCollectionQueryArgs = {
board_id?: string;
categories?: ImageCategory[];
search_term?: string;
order_dir?: SQLiteDirection;
is_intermediate: boolean;
};
// Types
type Collection = 'starred' | 'unstarred';
interface RangeKey {
interface PositionInfo {
collection: Collection;
offset: number;
limit: number;
itemIndex: number;
}
interface PositionQuery extends RangeKey {
imageIndex: number;
}
type PositionInfo = {
totalCount: number;
starredCount: number;
unstarredCount: number;
starredEnd: number;
};
// Query options factory functions to prevent recreation on every render
const countsQueryOptions = {
selectFromResult: ({ data, isLoading }) => {
const positionInfo: PositionInfo | null = data
? {
totalCount: data.total_count ?? 0,
starredCount: data.starred_count ?? 0,
unstarredCount: data.unstarred_count ?? 0,
starredEnd: (data.starred_count ?? 0) - 1,
}
: null;
// Constants
const RANGE_SIZE = 50;
// Helper to calculate which collection and range an index belongs to
const getPositionInfo = (index: number, starredCount: number): PositionInfo => {
if (index < starredCount) {
// Starred collection
const offset = Math.floor(index / RANGE_SIZE) * RANGE_SIZE;
return {
positionInfo,
isLoading,
collection: 'starred',
offset,
itemIndex: index - offset,
};
},
} satisfies Parameters<typeof useGetImageCollectionCountsQuery>[1];
const createImageCollectionQueryOptions = (queryParams: PositionQuery | null) =>
({
skip: !queryParams,
selectFromResult: (result) => {
return {
imageDTO: (queryParams && result.data?.items?.[queryParams.imageIndex]) || null,
};
},
}) satisfies Parameters<typeof useGetImageCollectionQuery>[1];
// Placeholder image component for now
const ImagePlaceholder = memo(({ imageDTO }: { imageDTO: ImageDTO }) => (
<Image src={imageDTO.thumbnail_url} w="full" h="full" objectFit="contain" />
));
ImagePlaceholder.displayName = 'ImagePlaceholder';
// Loading skeleton component
const ImageSkeleton = memo(() => <Skeleton w="full" h="full" />);
ImageSkeleton.displayName = 'ImageSkeleton';
// Hook to manage position calculations and range loading
const useVirtualImageData = () => {
const queryArgs = useAppSelector(selectImageCollectionQueryArgs);
// Get position info derived from counts using selectFromResult
const { positionInfo, isLoading } = useGetImageCollectionCountsQuery(queryArgs, countsQueryOptions);
const [triggerGetImageCollection] = useLazyGetImageCollectionQuery();
// Function to get query params for a specific position
const getQueryParamsForPosition = useCallback(
(index: number): PositionQuery | null => {
if (!positionInfo) {
return null;
}
if (positionInfo.starredCount === 0 || index >= positionInfo.starredCount) {
// This position is in the unstarred collection
const unstarredOffset = index - positionInfo.starredCount;
const rangeOffset = Math.floor(unstarredOffset / 50) * 50;
return {
collection: 'unstarred',
offset: rangeOffset,
limit: 50,
imageIndex: unstarredOffset % 50,
};
} else {
// This position is in the starred collection
const rangeOffset = Math.floor(index / 50) * 50;
return {
collection: 'starred',
offset: rangeOffset,
limit: 50,
imageIndex: index % 50,
};
}
},
[positionInfo]
);
// Function to calculate required ranges for a viewport and trigger lazy queries
const updateRequiredRanges = useCallback(
(startIndex: number, endIndex: number) => {
if (!positionInfo) {
return;
}
for (let i = startIndex; i <= endIndex; i++) {
const queryParams = getQueryParamsForPosition(i);
if (queryParams) {
const { collection, offset, limit } = queryParams;
triggerGetImageCollection(
{
collection,
offset,
limit,
...queryArgs,
},
true
);
}
}
},
[positionInfo, getQueryParamsForPosition, triggerGetImageCollection, queryArgs]
);
return {
positionInfo,
isLoading,
getQueryParamsForPosition,
queryArgs,
updateRequiredRanges,
};
} else {
// Unstarred collection
const unstarredIndex = index - starredCount;
const offset = Math.floor(unstarredIndex / RANGE_SIZE) * RANGE_SIZE;
return {
collection: 'unstarred',
offset,
itemIndex: unstarredIndex - offset,
};
}
};
// Hook to get image data for a specific position using selectFromResult
const useImageAtPosition = (index: number) => {
const { getQueryParamsForPosition, queryArgs } = useVirtualImageData();
// Hook to get image at a specific position
const useImageAtPosition = (index: number, starredCount: number, queryArgs: ImageCollectionQueryArgs) => {
const positionInfo = useMemo(() => getPositionInfo(index, starredCount), [index, starredCount]);
const queryParams = getQueryParamsForPosition(index);
const { imageDTO } = useGetImageCollectionQuery(
queryParams
? {
collection: queryParams.collection,
offset: queryParams.offset,
limit: queryParams.limit,
...queryArgs,
}
: skipToken,
createImageCollectionQueryOptions(queryParams)
const arg = useMemo(
() =>
({
collection: positionInfo.collection,
offset: positionInfo.offset,
limit: RANGE_SIZE,
...queryArgs,
}) satisfies Parameters<typeof useGetImageCollectionQuery>[0],
[positionInfo.collection, positionInfo.offset, queryArgs]
);
const options = useMemo(
() =>
({
selectFromResult: ({ data }) => {
if (!data) {
return { imageDTO: null };
} else {
return {
imageDTO: data.items[positionInfo.itemIndex] || null,
};
}
},
}) satisfies Parameters<typeof useGetImageCollectionQuery>[1],
[positionInfo.itemIndex]
);
const { imageDTO } = useGetImageCollectionQuery(arg, options);
return imageDTO;
};
// Component to render a single image at a position
const ImageAtPosition = memo(({ index }: { index: number }) => {
const imageDTO = useImageAtPosition(index);
type ImageAtPositionProps = {
index: number;
starredCount: number;
queryArgs: ImageCollectionQueryArgs;
};
if (imageDTO) {
return <ImagePlaceholder imageDTO={imageDTO} />;
// Individual image component
const ImageAtPosition = memo(({ index, starredCount, queryArgs }: ImageAtPositionProps) => {
const imageDTO = useImageAtPosition(index, starredCount, queryArgs);
if (!imageDTO) {
return <Skeleton w="full" h="full" />;
}
return <ImageSkeleton />;
return <GalleryImage imageDTO={imageDTO} />;
});
ImageAtPosition.displayName = 'ImageAtPosition';
export const NewGallery = memo(() => {
const { positionInfo, isLoading, updateRequiredRanges } = useVirtualImageData();
export const useDebouncedImageCollectionQueryArgs = () => {
const _queryArgs = useAppSelector(selectImageCollectionQueryArgs);
const [queryArgs] = useDebounce(_queryArgs, 500);
return queryArgs;
};
// Handle range changes from VirtuosoGrid
const handleRangeChanged = useCallback(
(range: { startIndex: number; endIndex: number }) => {
updateRequiredRanges(range.startIndex, range.endIndex);
// Main gallery component
export const NewGallery = memo(() => {
const queryArgs = useDebouncedImageCollectionQueryArgs();
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const { data: counts, isLoading } = useGetImageCollectionCountsQuery(queryArgs);
const starredCount = counts?.starred_count ?? 0;
const totalCount = counts?.total_count ?? 0;
// Reset scroll position when query parameters change
useEffect(() => {
if (virtuosoRef.current && totalCount > 0) {
virtuosoRef.current.scrollToIndex({ index: 0, behavior: 'auto' });
}
}, [queryArgs, totalCount]);
// Memoized item content function
const itemContent = useCallback(
(index: number) => {
return <ImageAtPosition index={index} starredCount={starredCount} queryArgs={queryArgs} />;
},
[updateRequiredRanges]
[starredCount, queryArgs]
);
// Render item at specific index
const itemContent = useCallback((index: number) => {
return <ImageAtPosition index={index} />;
// Memoized compute key function
const computeItemKey = useCallback(
(index: number) => {
return `${JSON.stringify(queryArgs)}-${index}`;
},
[queryArgs]
);
// Handle range changes (for prefetching)
const handleRangeChanged = useCallback((_range: ListRange) => {
// RTK Query will automatically handle caching and deduplication
// No need to manually trigger queries here
}, []);
// Compute item key using position index - let RTK Query handle the caching
const computeItemKey = useCallback((index: number) => `position-${index}`, []);
const rootRef = useRef<HTMLDivElement>(null);
const [scroller, setScroller] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
events: {
initialized(osInstance) {
// force overflow styles
const { viewport } = osInstance.elements();
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
},
},
});
useEffect(() => {
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => osInstance()?.destroy();
}, [scroller, initialize, osInstance]);
if (isLoading) {
return (
@@ -212,7 +195,7 @@ export const NewGallery = memo(() => {
);
}
if (!positionInfo || positionInfo.totalCount === 0) {
if (totalCount === 0) {
return (
<Flex height="100%" alignItems="center" justifyContent="center">
<Text color="gray.500">No images found</Text>
@@ -221,15 +204,18 @@ export const NewGallery = memo(() => {
}
return (
<Box height="100%" width="100%">
<Box data-overlayscrollbars-initialize="" ref={rootRef} w="full" h="full">
<VirtuosoGrid
totalCount={positionInfo.totalCount}
ref={virtuosoRef}
totalCount={totalCount}
increaseViewportBy={1024}
rangeChanged={handleRangeChanged}
itemContent={itemContent}
style={style}
computeItemKey={computeItemKey}
components={components}
style={style}
scrollerRef={setScroller}
scrollSeekConfiguration={scrollSeekConfiguration}
/>
</Box>
);
@@ -237,9 +223,18 @@ export const NewGallery = memo(() => {
NewGallery.displayName = 'NewGallery';
const scrollSeekConfiguration: ScrollSeekConfiguration = {
enter: (velocity) => {
return velocity > 500;
},
exit: (velocity) => velocity < 500,
};
// Styles
const style = { height: '100%', width: '100%' };
const ListComponent = forwardRef((props, ref) => {
// Grid components
const ListComponent: GridComponents['List'] = forwardRef((props, ref) => {
const galleryImageMinimumWidth = useAppSelector(selectGalleryImageMinimumWidth);
return (
@@ -247,15 +242,26 @@ const ListComponent = forwardRef((props, ref) => {
ref={ref}
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
gap={2}
padding={2}
{...props}
/>
);
});
ListComponent.displayName = 'ListComponent';
const ItemComponent = forwardRef((props, ref) => <GridItem ref={ref} aspectRatio="1/1" {...props} />);
const ItemComponent: GridComponents['Item'] = forwardRef((props, ref) => (
<GridItem ref={ref} aspectRatio="1/1" {...props} />
));
ItemComponent.displayName = 'ItemComponent';
const components = {
const FillSkeleton: GridComponents['ScrollSeekPlaceholder'] = forwardRef((props, ref) => (
<GridItem ref={ref} {...props}>
<Skeleton w="full" h="full" />
</GridItem>
));
FillSkeleton.displayName = 'FillSkeleton';
const components: GridComponents = {
Item: ItemComponent,
List: ListComponent,
ScrollSeekPlaceholder: FillSkeleton,
};