refactor: gallery scroll (improved impl)

This commit is contained in:
psychedelicious
2025-06-24 17:23:04 +10:00
parent 2c8ce6f2f4
commit 87909a06a8
3 changed files with 154 additions and 63 deletions

View File

@@ -1,15 +1,71 @@
import { Box, Flex, forwardRef, Grid, GridItem, Image, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors';
import { memo, useCallback, useMemo } from 'react';
import {
selectGalleryImageMinimumWidth,
selectImageCollectionQueryArgs,
} from 'features/gallery/store/gallerySelectors';
import { memo, useCallback } from 'react';
import { VirtuosoGrid } from 'react-virtuoso';
import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images';
import {
useGetImageCollectionCountsQuery,
useGetImageCollectionQuery,
useLazyGetImageCollectionQuery,
} from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
// Types for range management
type Collection = 'starred' | 'unstarred';
interface RangeKey {
collection: Collection;
offset: number;
limit: number;
}
interface PositionQuery extends RangeKey {
imageIndex: number;
}
type PositionInfo = {
totalCount: number;
starredCount: number;
unstarredCount: number;
starredEnd: number;
};
// Query options factory functions to prevent recreation on every render
const countsQueryOptions = {
selectFromResult: ({ data, isLoading }) => {
const positionInfo: PositionInfo | null = data
? {
totalCount: data.total_count ?? 0,
starredCount: data.starred_count ?? 0,
unstarredCount: data.unstarred_count ?? 0,
starredEnd: (data.starred_count ?? 0) - 1,
}
: null;
return {
positionInfo,
isLoading,
};
},
} satisfies Parameters<typeof useGetImageCollectionCountsQuery>[1];
const createImageCollectionQueryOptions = (queryParams: PositionQuery | null) =>
({
skip: !queryParams,
selectFromResult: (result) => {
return {
imageDTO: (queryParams && result.data?.items?.[queryParams.imageIndex]) || null,
};
},
}) satisfies Parameters<typeof useGetImageCollectionQuery>[1];
// Placeholder image component for now
const ImagePlaceholder = memo(({ image }: { image: ImageDTO }) => (
<Image src={image.thumbnail_url} w="full" h="full" objectFit="contain" />
const ImagePlaceholder = memo(({ imageDTO }: { imageDTO: ImageDTO }) => (
<Image src={imageDTO.thumbnail_url} w="full" h="full" objectFit="contain" />
));
ImagePlaceholder.displayName = 'ImagePlaceholder';
@@ -19,30 +75,18 @@ const ImageSkeleton = memo(() => <Skeleton w="full" h="full" />);
ImageSkeleton.displayName = 'ImageSkeleton';
// Hook to manage position calculations and image access
// Hook to manage position calculations and range loading
const useVirtualImageData = () => {
const queryArgs = useAppSelector(selectImageCollectionQueryArgs);
// Get total counts for position mapping
const { data: counts, isLoading: countsLoading } = useGetImageCollectionCountsQuery(queryArgs);
// Get position info derived from counts using selectFromResult
const { positionInfo, isLoading } = useGetImageCollectionCountsQuery(queryArgs, countsQueryOptions);
// Calculate position mappings
const positionInfo = useMemo(() => {
if (!counts) {
return null;
}
return {
totalCount: counts.total_count,
starredCount: counts.starred_count ?? 0,
unstarredCount: counts.unstarred_count ?? 0,
starredEnd: (counts.starred_count ?? 0) - 1,
};
}, [counts]);
const [triggerGetImageCollection] = useLazyGetImageCollectionQuery();
// Function to get query params for a specific position
const getQueryParamsForPosition = useCallback(
(index: number) => {
(index: number): PositionQuery | null => {
if (!positionInfo) {
return null;
}
@@ -52,7 +96,7 @@ const useVirtualImageData = () => {
const unstarredOffset = index - positionInfo.starredCount;
const rangeOffset = Math.floor(unstarredOffset / 50) * 50;
return {
collection: 'unstarred' as const,
collection: 'unstarred',
offset: rangeOffset,
limit: 50,
imageIndex: unstarredOffset % 50,
@@ -61,7 +105,7 @@ const useVirtualImageData = () => {
// This position is in the starred collection
const rangeOffset = Math.floor(index / 50) * 50;
return {
collection: 'starred' as const,
collection: 'starred',
offset: rangeOffset,
limit: 50,
imageIndex: index % 50,
@@ -71,21 +115,48 @@ const useVirtualImageData = () => {
[positionInfo]
);
// Function to calculate required ranges for a viewport and trigger lazy queries
const updateRequiredRanges = useCallback(
(startIndex: number, endIndex: number) => {
if (!positionInfo) {
return;
}
for (let i = startIndex; i <= endIndex; i++) {
const queryParams = getQueryParamsForPosition(i);
if (queryParams) {
const { collection, offset, limit } = queryParams;
triggerGetImageCollection(
{
collection,
offset,
limit,
...queryArgs,
},
true
);
}
}
},
[positionInfo, getQueryParamsForPosition, triggerGetImageCollection, queryArgs]
);
return {
positionInfo,
countsLoading,
isLoading,
getQueryParamsForPosition,
queryArgs,
updateRequiredRanges,
};
};
// Hook to get image data for a specific position using RTK Query cache
// Hook to get image data for a specific position using selectFromResult
const useImageAtPosition = (index: number) => {
const { getQueryParamsForPosition, queryArgs } = useVirtualImageData();
const queryParams = getQueryParamsForPosition(index);
const { data } = useGetImageCollectionQuery(
const { imageDTO } = useGetImageCollectionQuery(
queryParams
? {
collection: queryParams.collection,
@@ -93,22 +164,19 @@ const useImageAtPosition = (index: number) => {
limit: queryParams.limit,
...queryArgs,
}
: skipToken
: skipToken,
createImageCollectionQueryOptions(queryParams)
);
if (!queryParams || !data?.items) {
return null;
}
return data.items[queryParams.imageIndex] || null;
return imageDTO;
};
// Component to render a single image at a position
const ImageAtPosition = memo(({ index }: { index: number }) => {
const image = useImageAtPosition(index);
const imageDTO = useImageAtPosition(index);
if (image) {
return <ImagePlaceholder image={image} />;
if (imageDTO) {
return <ImagePlaceholder imageDTO={imageDTO} />;
}
return <ImageSkeleton />;
@@ -117,7 +185,15 @@ const ImageAtPosition = memo(({ index }: { index: number }) => {
ImageAtPosition.displayName = 'ImageAtPosition';
export const NewGallery = memo(() => {
const { positionInfo, countsLoading } = useVirtualImageData();
const { positionInfo, isLoading, updateRequiredRanges } = useVirtualImageData();
// Handle range changes from VirtuosoGrid
const handleRangeChanged = useCallback(
(range: { startIndex: number; endIndex: number }) => {
updateRequiredRanges(range.startIndex, range.endIndex);
},
[updateRequiredRanges]
);
// Render item at specific index
const itemContent = useCallback((index: number) => {
@@ -127,7 +203,7 @@ export const NewGallery = memo(() => {
// Compute item key using position index - let RTK Query handle the caching
const computeItemKey = useCallback((index: number) => `position-${index}`, []);
if (countsLoading) {
if (isLoading) {
return (
<Flex height="100%" alignItems="center" justifyContent="center">
<Spinner size="lg" />
@@ -146,10 +222,10 @@ export const NewGallery = memo(() => {
return (
<Box height="100%" width="100%">
{/* Virtualized grid */}
<VirtuosoGrid
totalCount={positionInfo.totalCount}
overscan={200}
increaseViewportBy={1024}
rangeChanged={handleRangeChanged}
itemContent={itemContent}
style={style}
computeItemKey={computeItemKey}
@@ -163,9 +239,19 @@ NewGallery.displayName = 'NewGallery';
const style = { height: '100%', width: '100%' };
const ListComponent = forwardRef((props, ref) => (
<Grid ref={ref} gridTemplateColumns="repeat(auto-fill, minmax(64px, 1fr))" gap={2} padding={2} {...props} />
));
const ListComponent = forwardRef((props, ref) => {
const galleryImageMinimumWidth = useAppSelector(selectGalleryImageMinimumWidth);
return (
<Grid
ref={ref}
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
gap={2}
padding={2}
{...props}
/>
);
});
const ItemComponent = forwardRef((props, ref) => <GridItem ref={ref} aspectRatio="1/1" {...props} />);

View File

@@ -26,7 +26,8 @@ import { buildBoardsUrl } from './boards';
* buildImagesUrl('some-path')
* // '/api/v1/images/some-path'
*/
const buildImagesUrl = (path: string = '') => buildV1Url(`images/${path}`);
const buildImagesUrl = (path: string = '', query?: Parameters<typeof buildV1Url>[1]) =>
buildV1Url(`images/${path}`, query);
/**
* Builds an endpoint URL for the board_images router
@@ -428,9 +429,8 @@ export const imagesApi = api.injectEndpoints({
paths['/api/v1/images/collections/counts']['get']['parameters']['query']
>({
query: (queryArgs) => ({
url: buildImagesUrl('collections/counts'),
url: buildImagesUrl('collections/counts', queryArgs),
method: 'GET',
params: queryArgs,
}),
providesTags: ['ImageCollectionCounts', 'FetchOnReconnect'],
}),
@@ -443,28 +443,27 @@ export const imagesApi = api.injectEndpoints({
paths['/api/v1/images/collections/{collection}']['get']['parameters']['query']
>({
query: ({ collection, ...queryArgs }) => ({
url: buildImagesUrl(`collections/${collection}`),
url: buildImagesUrl(`collections/${collection}`, queryArgs),
method: 'GET',
params: queryArgs,
}),
providesTags: (result, error, { collection, board_id, categories }) => {
const cacheKey = `${collection}-${board_id || 'all'}-${categories?.join(',') || 'all'}`;
return [{ type: 'ImageCollection', id: cacheKey }, 'FetchOnReconnect'];
},
async onQueryStarted(_, { dispatch, queryFulfilled }) {
// Populate the getImageDTO cache with these images, similar to listImages
const res = await queryFulfilled;
const imageDTOs = res.data.items;
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
for (const imageDTO of imageDTOs) {
updates.push({
endpointName: 'getImageDTO',
arg: imageDTO.image_name,
value: imageDTO,
});
}
dispatch(imagesApi.util.upsertQueryEntries(updates));
},
// async onQueryStarted(_, { dispatch, queryFulfilled }) {
// // Populate the getImageDTO cache with these images, similar to listImages
// const res = await queryFulfilled;
// const imageDTOs = res.data.items;
// const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
// for (const imageDTO of imageDTOs) {
// updates.push({
// endpointName: 'getImageDTO',
// arg: imageDTO.image_name,
// value: imageDTO,
// });
// }
// dispatch(imagesApi.util.upsertQueryEntries(updates));
// },
}),
}),
});

View File

@@ -10,6 +10,7 @@ import { buildCreateApi, coreModule, fetchBaseQuery, reactHooksModule } from '@r
import { $authToken } from 'app/store/nanostores/authToken';
import { $baseUrl } from 'app/store/nanostores/baseUrl';
import { $projectId } from 'app/store/nanostores/projectId';
import queryString from 'query-string';
const tagTypes = [
'AppVersion',
@@ -133,5 +134,10 @@ function getCircularReplacer() {
};
}
export const buildV1Url = (path: string): string => `api/v1/${path}`;
export const buildV1Url = (path: string, query?: Parameters<typeof queryString.stringify>[0]): string => {
if (!query) {
return `api/v1/${path}`;
}
return `api/v1/${path}?${queryString.stringify(query)}`;
};
export const buildV2Url = (path: string): string => `api/v2/${path}`;