mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 22:35:27 -05:00
refactor: gallery scroll (improved impl)
This commit is contained in:
@@ -1,15 +1,71 @@
|
||||
import { Box, Flex, forwardRef, Grid, GridItem, Image, Skeleton, Spinner, Text } from '@invoke-ai/ui-library';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import {
|
||||
selectGalleryImageMinimumWidth,
|
||||
selectImageCollectionQueryArgs,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images';
|
||||
import {
|
||||
useGetImageCollectionCountsQuery,
|
||||
useGetImageCollectionQuery,
|
||||
useLazyGetImageCollectionQuery,
|
||||
} from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
// Types for range management
|
||||
type Collection = 'starred' | 'unstarred';
|
||||
|
||||
interface RangeKey {
|
||||
collection: Collection;
|
||||
offset: number;
|
||||
limit: number;
|
||||
}
|
||||
|
||||
interface PositionQuery extends RangeKey {
|
||||
imageIndex: number;
|
||||
}
|
||||
|
||||
type PositionInfo = {
|
||||
totalCount: number;
|
||||
starredCount: number;
|
||||
unstarredCount: number;
|
||||
starredEnd: number;
|
||||
};
|
||||
|
||||
// Query options factory functions to prevent recreation on every render
|
||||
const countsQueryOptions = {
|
||||
selectFromResult: ({ data, isLoading }) => {
|
||||
const positionInfo: PositionInfo | null = data
|
||||
? {
|
||||
totalCount: data.total_count ?? 0,
|
||||
starredCount: data.starred_count ?? 0,
|
||||
unstarredCount: data.unstarred_count ?? 0,
|
||||
starredEnd: (data.starred_count ?? 0) - 1,
|
||||
}
|
||||
: null;
|
||||
|
||||
return {
|
||||
positionInfo,
|
||||
isLoading,
|
||||
};
|
||||
},
|
||||
} satisfies Parameters<typeof useGetImageCollectionCountsQuery>[1];
|
||||
|
||||
const createImageCollectionQueryOptions = (queryParams: PositionQuery | null) =>
|
||||
({
|
||||
skip: !queryParams,
|
||||
selectFromResult: (result) => {
|
||||
return {
|
||||
imageDTO: (queryParams && result.data?.items?.[queryParams.imageIndex]) || null,
|
||||
};
|
||||
},
|
||||
}) satisfies Parameters<typeof useGetImageCollectionQuery>[1];
|
||||
|
||||
// Placeholder image component for now
|
||||
const ImagePlaceholder = memo(({ image }: { image: ImageDTO }) => (
|
||||
<Image src={image.thumbnail_url} w="full" h="full" objectFit="contain" />
|
||||
const ImagePlaceholder = memo(({ imageDTO }: { imageDTO: ImageDTO }) => (
|
||||
<Image src={imageDTO.thumbnail_url} w="full" h="full" objectFit="contain" />
|
||||
));
|
||||
|
||||
ImagePlaceholder.displayName = 'ImagePlaceholder';
|
||||
@@ -19,30 +75,18 @@ const ImageSkeleton = memo(() => <Skeleton w="full" h="full" />);
|
||||
|
||||
ImageSkeleton.displayName = 'ImageSkeleton';
|
||||
|
||||
// Hook to manage position calculations and image access
|
||||
// Hook to manage position calculations and range loading
|
||||
const useVirtualImageData = () => {
|
||||
const queryArgs = useAppSelector(selectImageCollectionQueryArgs);
|
||||
|
||||
// Get total counts for position mapping
|
||||
const { data: counts, isLoading: countsLoading } = useGetImageCollectionCountsQuery(queryArgs);
|
||||
// Get position info derived from counts using selectFromResult
|
||||
const { positionInfo, isLoading } = useGetImageCollectionCountsQuery(queryArgs, countsQueryOptions);
|
||||
|
||||
// Calculate position mappings
|
||||
const positionInfo = useMemo(() => {
|
||||
if (!counts) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
totalCount: counts.total_count,
|
||||
starredCount: counts.starred_count ?? 0,
|
||||
unstarredCount: counts.unstarred_count ?? 0,
|
||||
starredEnd: (counts.starred_count ?? 0) - 1,
|
||||
};
|
||||
}, [counts]);
|
||||
const [triggerGetImageCollection] = useLazyGetImageCollectionQuery();
|
||||
|
||||
// Function to get query params for a specific position
|
||||
const getQueryParamsForPosition = useCallback(
|
||||
(index: number) => {
|
||||
(index: number): PositionQuery | null => {
|
||||
if (!positionInfo) {
|
||||
return null;
|
||||
}
|
||||
@@ -52,7 +96,7 @@ const useVirtualImageData = () => {
|
||||
const unstarredOffset = index - positionInfo.starredCount;
|
||||
const rangeOffset = Math.floor(unstarredOffset / 50) * 50;
|
||||
return {
|
||||
collection: 'unstarred' as const,
|
||||
collection: 'unstarred',
|
||||
offset: rangeOffset,
|
||||
limit: 50,
|
||||
imageIndex: unstarredOffset % 50,
|
||||
@@ -61,7 +105,7 @@ const useVirtualImageData = () => {
|
||||
// This position is in the starred collection
|
||||
const rangeOffset = Math.floor(index / 50) * 50;
|
||||
return {
|
||||
collection: 'starred' as const,
|
||||
collection: 'starred',
|
||||
offset: rangeOffset,
|
||||
limit: 50,
|
||||
imageIndex: index % 50,
|
||||
@@ -71,21 +115,48 @@ const useVirtualImageData = () => {
|
||||
[positionInfo]
|
||||
);
|
||||
|
||||
// Function to calculate required ranges for a viewport and trigger lazy queries
|
||||
const updateRequiredRanges = useCallback(
|
||||
(startIndex: number, endIndex: number) => {
|
||||
if (!positionInfo) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (let i = startIndex; i <= endIndex; i++) {
|
||||
const queryParams = getQueryParamsForPosition(i);
|
||||
if (queryParams) {
|
||||
const { collection, offset, limit } = queryParams;
|
||||
triggerGetImageCollection(
|
||||
{
|
||||
collection,
|
||||
offset,
|
||||
limit,
|
||||
...queryArgs,
|
||||
},
|
||||
true
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
[positionInfo, getQueryParamsForPosition, triggerGetImageCollection, queryArgs]
|
||||
);
|
||||
|
||||
return {
|
||||
positionInfo,
|
||||
countsLoading,
|
||||
isLoading,
|
||||
getQueryParamsForPosition,
|
||||
queryArgs,
|
||||
updateRequiredRanges,
|
||||
};
|
||||
};
|
||||
|
||||
// Hook to get image data for a specific position using RTK Query cache
|
||||
// Hook to get image data for a specific position using selectFromResult
|
||||
const useImageAtPosition = (index: number) => {
|
||||
const { getQueryParamsForPosition, queryArgs } = useVirtualImageData();
|
||||
|
||||
const queryParams = getQueryParamsForPosition(index);
|
||||
|
||||
const { data } = useGetImageCollectionQuery(
|
||||
const { imageDTO } = useGetImageCollectionQuery(
|
||||
queryParams
|
||||
? {
|
||||
collection: queryParams.collection,
|
||||
@@ -93,22 +164,19 @@ const useImageAtPosition = (index: number) => {
|
||||
limit: queryParams.limit,
|
||||
...queryArgs,
|
||||
}
|
||||
: skipToken
|
||||
: skipToken,
|
||||
createImageCollectionQueryOptions(queryParams)
|
||||
);
|
||||
|
||||
if (!queryParams || !data?.items) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return data.items[queryParams.imageIndex] || null;
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
// Component to render a single image at a position
|
||||
const ImageAtPosition = memo(({ index }: { index: number }) => {
|
||||
const image = useImageAtPosition(index);
|
||||
const imageDTO = useImageAtPosition(index);
|
||||
|
||||
if (image) {
|
||||
return <ImagePlaceholder image={image} />;
|
||||
if (imageDTO) {
|
||||
return <ImagePlaceholder imageDTO={imageDTO} />;
|
||||
}
|
||||
|
||||
return <ImageSkeleton />;
|
||||
@@ -117,7 +185,15 @@ const ImageAtPosition = memo(({ index }: { index: number }) => {
|
||||
ImageAtPosition.displayName = 'ImageAtPosition';
|
||||
|
||||
export const NewGallery = memo(() => {
|
||||
const { positionInfo, countsLoading } = useVirtualImageData();
|
||||
const { positionInfo, isLoading, updateRequiredRanges } = useVirtualImageData();
|
||||
|
||||
// Handle range changes from VirtuosoGrid
|
||||
const handleRangeChanged = useCallback(
|
||||
(range: { startIndex: number; endIndex: number }) => {
|
||||
updateRequiredRanges(range.startIndex, range.endIndex);
|
||||
},
|
||||
[updateRequiredRanges]
|
||||
);
|
||||
|
||||
// Render item at specific index
|
||||
const itemContent = useCallback((index: number) => {
|
||||
@@ -127,7 +203,7 @@ export const NewGallery = memo(() => {
|
||||
// Compute item key using position index - let RTK Query handle the caching
|
||||
const computeItemKey = useCallback((index: number) => `position-${index}`, []);
|
||||
|
||||
if (countsLoading) {
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Flex height="100%" alignItems="center" justifyContent="center">
|
||||
<Spinner size="lg" />
|
||||
@@ -146,10 +222,10 @@ export const NewGallery = memo(() => {
|
||||
|
||||
return (
|
||||
<Box height="100%" width="100%">
|
||||
{/* Virtualized grid */}
|
||||
<VirtuosoGrid
|
||||
totalCount={positionInfo.totalCount}
|
||||
overscan={200}
|
||||
increaseViewportBy={1024}
|
||||
rangeChanged={handleRangeChanged}
|
||||
itemContent={itemContent}
|
||||
style={style}
|
||||
computeItemKey={computeItemKey}
|
||||
@@ -163,9 +239,19 @@ NewGallery.displayName = 'NewGallery';
|
||||
|
||||
const style = { height: '100%', width: '100%' };
|
||||
|
||||
const ListComponent = forwardRef((props, ref) => (
|
||||
<Grid ref={ref} gridTemplateColumns="repeat(auto-fill, minmax(64px, 1fr))" gap={2} padding={2} {...props} />
|
||||
));
|
||||
const ListComponent = forwardRef((props, ref) => {
|
||||
const galleryImageMinimumWidth = useAppSelector(selectGalleryImageMinimumWidth);
|
||||
|
||||
return (
|
||||
<Grid
|
||||
ref={ref}
|
||||
gridTemplateColumns={`repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`}
|
||||
gap={2}
|
||||
padding={2}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
const ItemComponent = forwardRef((props, ref) => <GridItem ref={ref} aspectRatio="1/1" {...props} />);
|
||||
|
||||
|
||||
@@ -26,7 +26,8 @@ import { buildBoardsUrl } from './boards';
|
||||
* buildImagesUrl('some-path')
|
||||
* // '/api/v1/images/some-path'
|
||||
*/
|
||||
const buildImagesUrl = (path: string = '') => buildV1Url(`images/${path}`);
|
||||
const buildImagesUrl = (path: string = '', query?: Parameters<typeof buildV1Url>[1]) =>
|
||||
buildV1Url(`images/${path}`, query);
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the board_images router
|
||||
@@ -428,9 +429,8 @@ export const imagesApi = api.injectEndpoints({
|
||||
paths['/api/v1/images/collections/counts']['get']['parameters']['query']
|
||||
>({
|
||||
query: (queryArgs) => ({
|
||||
url: buildImagesUrl('collections/counts'),
|
||||
url: buildImagesUrl('collections/counts', queryArgs),
|
||||
method: 'GET',
|
||||
params: queryArgs,
|
||||
}),
|
||||
providesTags: ['ImageCollectionCounts', 'FetchOnReconnect'],
|
||||
}),
|
||||
@@ -443,28 +443,27 @@ export const imagesApi = api.injectEndpoints({
|
||||
paths['/api/v1/images/collections/{collection}']['get']['parameters']['query']
|
||||
>({
|
||||
query: ({ collection, ...queryArgs }) => ({
|
||||
url: buildImagesUrl(`collections/${collection}`),
|
||||
url: buildImagesUrl(`collections/${collection}`, queryArgs),
|
||||
method: 'GET',
|
||||
params: queryArgs,
|
||||
}),
|
||||
providesTags: (result, error, { collection, board_id, categories }) => {
|
||||
const cacheKey = `${collection}-${board_id || 'all'}-${categories?.join(',') || 'all'}`;
|
||||
return [{ type: 'ImageCollection', id: cacheKey }, 'FetchOnReconnect'];
|
||||
},
|
||||
async onQueryStarted(_, { dispatch, queryFulfilled }) {
|
||||
// Populate the getImageDTO cache with these images, similar to listImages
|
||||
const res = await queryFulfilled;
|
||||
const imageDTOs = res.data.items;
|
||||
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
|
||||
for (const imageDTO of imageDTOs) {
|
||||
updates.push({
|
||||
endpointName: 'getImageDTO',
|
||||
arg: imageDTO.image_name,
|
||||
value: imageDTO,
|
||||
});
|
||||
}
|
||||
dispatch(imagesApi.util.upsertQueryEntries(updates));
|
||||
},
|
||||
// async onQueryStarted(_, { dispatch, queryFulfilled }) {
|
||||
// // Populate the getImageDTO cache with these images, similar to listImages
|
||||
// const res = await queryFulfilled;
|
||||
// const imageDTOs = res.data.items;
|
||||
// const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
|
||||
// for (const imageDTO of imageDTOs) {
|
||||
// updates.push({
|
||||
// endpointName: 'getImageDTO',
|
||||
// arg: imageDTO.image_name,
|
||||
// value: imageDTO,
|
||||
// });
|
||||
// }
|
||||
// dispatch(imagesApi.util.upsertQueryEntries(updates));
|
||||
// },
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -10,6 +10,7 @@ import { buildCreateApi, coreModule, fetchBaseQuery, reactHooksModule } from '@r
|
||||
import { $authToken } from 'app/store/nanostores/authToken';
|
||||
import { $baseUrl } from 'app/store/nanostores/baseUrl';
|
||||
import { $projectId } from 'app/store/nanostores/projectId';
|
||||
import queryString from 'query-string';
|
||||
|
||||
const tagTypes = [
|
||||
'AppVersion',
|
||||
@@ -133,5 +134,10 @@ function getCircularReplacer() {
|
||||
};
|
||||
}
|
||||
|
||||
export const buildV1Url = (path: string): string => `api/v1/${path}`;
|
||||
export const buildV1Url = (path: string, query?: Parameters<typeof queryString.stringify>[0]): string => {
|
||||
if (!query) {
|
||||
return `api/v1/${path}`;
|
||||
}
|
||||
return `api/v1/${path}?${queryString.stringify(query)}`;
|
||||
};
|
||||
export const buildV2Url = (path: string): string => `api/v2/${path}`;
|
||||
|
||||
Reference in New Issue
Block a user