From 87909a06a8d2f8bc75c98b11048d9efba0c3944c Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 24 Jun 2025 17:23:04 +1000 Subject: [PATCH] refactor: gallery scroll (improved impl) --- .../gallery/components/NewGallery.tsx | 172 +++++++++++++----- .../web/src/services/api/endpoints/images.ts | 37 ++-- .../frontend/web/src/services/api/index.ts | 8 +- 3 files changed, 154 insertions(+), 63 deletions(-) diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx index 985a8a6427..9837db60d0 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -1,15 +1,71 @@ import { Box, Flex, forwardRef, Grid, GridItem, Image, Skeleton, Spinner, Text } from '@invoke-ai/ui-library'; import { skipToken } from '@reduxjs/toolkit/query'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectImageCollectionQueryArgs } from 'features/gallery/store/gallerySelectors'; -import { memo, useCallback, useMemo } from 'react'; +import { + selectGalleryImageMinimumWidth, + selectImageCollectionQueryArgs, +} from 'features/gallery/store/gallerySelectors'; +import { memo, useCallback } from 'react'; import { VirtuosoGrid } from 'react-virtuoso'; -import { useGetImageCollectionCountsQuery, useGetImageCollectionQuery } from 'services/api/endpoints/images'; +import { + useGetImageCollectionCountsQuery, + useGetImageCollectionQuery, + useLazyGetImageCollectionQuery, +} from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; +// Types for range management +type Collection = 'starred' | 'unstarred'; + +interface RangeKey { + collection: Collection; + offset: number; + limit: number; +} + +interface PositionQuery extends RangeKey { + imageIndex: number; +} + +type PositionInfo = { + totalCount: number; + starredCount: number; + unstarredCount: number; + starredEnd: number; +}; + +// Query options factory functions to prevent recreation on every render +const countsQueryOptions = { + selectFromResult: ({ data, isLoading }) => { + const positionInfo: PositionInfo | null = data + ? { + totalCount: data.total_count ?? 0, + starredCount: data.starred_count ?? 0, + unstarredCount: data.unstarred_count ?? 0, + starredEnd: (data.starred_count ?? 0) - 1, + } + : null; + + return { + positionInfo, + isLoading, + }; + }, +} satisfies Parameters[1]; + +const createImageCollectionQueryOptions = (queryParams: PositionQuery | null) => + ({ + skip: !queryParams, + selectFromResult: (result) => { + return { + imageDTO: (queryParams && result.data?.items?.[queryParams.imageIndex]) || null, + }; + }, + }) satisfies Parameters[1]; + // Placeholder image component for now -const ImagePlaceholder = memo(({ image }: { image: ImageDTO }) => ( - +const ImagePlaceholder = memo(({ imageDTO }: { imageDTO: ImageDTO }) => ( + )); ImagePlaceholder.displayName = 'ImagePlaceholder'; @@ -19,30 +75,18 @@ const ImageSkeleton = memo(() => ); ImageSkeleton.displayName = 'ImageSkeleton'; -// Hook to manage position calculations and image access +// Hook to manage position calculations and range loading const useVirtualImageData = () => { const queryArgs = useAppSelector(selectImageCollectionQueryArgs); - // Get total counts for position mapping - const { data: counts, isLoading: countsLoading } = useGetImageCollectionCountsQuery(queryArgs); + // Get position info derived from counts using selectFromResult + const { positionInfo, isLoading } = useGetImageCollectionCountsQuery(queryArgs, countsQueryOptions); - // Calculate position mappings - const positionInfo = useMemo(() => { - if (!counts) { - return null; - } - - return { - totalCount: counts.total_count, - starredCount: counts.starred_count ?? 0, - unstarredCount: counts.unstarred_count ?? 0, - starredEnd: (counts.starred_count ?? 0) - 1, - }; - }, [counts]); + const [triggerGetImageCollection] = useLazyGetImageCollectionQuery(); // Function to get query params for a specific position const getQueryParamsForPosition = useCallback( - (index: number) => { + (index: number): PositionQuery | null => { if (!positionInfo) { return null; } @@ -52,7 +96,7 @@ const useVirtualImageData = () => { const unstarredOffset = index - positionInfo.starredCount; const rangeOffset = Math.floor(unstarredOffset / 50) * 50; return { - collection: 'unstarred' as const, + collection: 'unstarred', offset: rangeOffset, limit: 50, imageIndex: unstarredOffset % 50, @@ -61,7 +105,7 @@ const useVirtualImageData = () => { // This position is in the starred collection const rangeOffset = Math.floor(index / 50) * 50; return { - collection: 'starred' as const, + collection: 'starred', offset: rangeOffset, limit: 50, imageIndex: index % 50, @@ -71,21 +115,48 @@ const useVirtualImageData = () => { [positionInfo] ); + // Function to calculate required ranges for a viewport and trigger lazy queries + const updateRequiredRanges = useCallback( + (startIndex: number, endIndex: number) => { + if (!positionInfo) { + return; + } + + for (let i = startIndex; i <= endIndex; i++) { + const queryParams = getQueryParamsForPosition(i); + if (queryParams) { + const { collection, offset, limit } = queryParams; + triggerGetImageCollection( + { + collection, + offset, + limit, + ...queryArgs, + }, + true + ); + } + } + }, + [positionInfo, getQueryParamsForPosition, triggerGetImageCollection, queryArgs] + ); + return { positionInfo, - countsLoading, + isLoading, getQueryParamsForPosition, queryArgs, + updateRequiredRanges, }; }; -// Hook to get image data for a specific position using RTK Query cache +// Hook to get image data for a specific position using selectFromResult const useImageAtPosition = (index: number) => { const { getQueryParamsForPosition, queryArgs } = useVirtualImageData(); const queryParams = getQueryParamsForPosition(index); - const { data } = useGetImageCollectionQuery( + const { imageDTO } = useGetImageCollectionQuery( queryParams ? { collection: queryParams.collection, @@ -93,22 +164,19 @@ const useImageAtPosition = (index: number) => { limit: queryParams.limit, ...queryArgs, } - : skipToken + : skipToken, + createImageCollectionQueryOptions(queryParams) ); - if (!queryParams || !data?.items) { - return null; - } - - return data.items[queryParams.imageIndex] || null; + return imageDTO; }; // Component to render a single image at a position const ImageAtPosition = memo(({ index }: { index: number }) => { - const image = useImageAtPosition(index); + const imageDTO = useImageAtPosition(index); - if (image) { - return ; + if (imageDTO) { + return ; } return ; @@ -117,7 +185,15 @@ const ImageAtPosition = memo(({ index }: { index: number }) => { ImageAtPosition.displayName = 'ImageAtPosition'; export const NewGallery = memo(() => { - const { positionInfo, countsLoading } = useVirtualImageData(); + const { positionInfo, isLoading, updateRequiredRanges } = useVirtualImageData(); + + // Handle range changes from VirtuosoGrid + const handleRangeChanged = useCallback( + (range: { startIndex: number; endIndex: number }) => { + updateRequiredRanges(range.startIndex, range.endIndex); + }, + [updateRequiredRanges] + ); // Render item at specific index const itemContent = useCallback((index: number) => { @@ -127,7 +203,7 @@ export const NewGallery = memo(() => { // Compute item key using position index - let RTK Query handle the caching const computeItemKey = useCallback((index: number) => `position-${index}`, []); - if (countsLoading) { + if (isLoading) { return ( @@ -146,10 +222,10 @@ export const NewGallery = memo(() => { return ( - {/* Virtualized grid */} ( - -)); +const ListComponent = forwardRef((props, ref) => { + const galleryImageMinimumWidth = useAppSelector(selectGalleryImageMinimumWidth); + + return ( + + ); +}); const ItemComponent = forwardRef((props, ref) => ); diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index f3d867b7d3..23d5de493e 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -26,7 +26,8 @@ import { buildBoardsUrl } from './boards'; * buildImagesUrl('some-path') * // '/api/v1/images/some-path' */ -const buildImagesUrl = (path: string = '') => buildV1Url(`images/${path}`); +const buildImagesUrl = (path: string = '', query?: Parameters[1]) => + buildV1Url(`images/${path}`, query); /** * Builds an endpoint URL for the board_images router @@ -428,9 +429,8 @@ export const imagesApi = api.injectEndpoints({ paths['/api/v1/images/collections/counts']['get']['parameters']['query'] >({ query: (queryArgs) => ({ - url: buildImagesUrl('collections/counts'), + url: buildImagesUrl('collections/counts', queryArgs), method: 'GET', - params: queryArgs, }), providesTags: ['ImageCollectionCounts', 'FetchOnReconnect'], }), @@ -443,28 +443,27 @@ export const imagesApi = api.injectEndpoints({ paths['/api/v1/images/collections/{collection}']['get']['parameters']['query'] >({ query: ({ collection, ...queryArgs }) => ({ - url: buildImagesUrl(`collections/${collection}`), + url: buildImagesUrl(`collections/${collection}`, queryArgs), method: 'GET', - params: queryArgs, }), providesTags: (result, error, { collection, board_id, categories }) => { const cacheKey = `${collection}-${board_id || 'all'}-${categories?.join(',') || 'all'}`; return [{ type: 'ImageCollection', id: cacheKey }, 'FetchOnReconnect']; }, - async onQueryStarted(_, { dispatch, queryFulfilled }) { - // Populate the getImageDTO cache with these images, similar to listImages - const res = await queryFulfilled; - const imageDTOs = res.data.items; - const updates: Param0 = []; - for (const imageDTO of imageDTOs) { - updates.push({ - endpointName: 'getImageDTO', - arg: imageDTO.image_name, - value: imageDTO, - }); - } - dispatch(imagesApi.util.upsertQueryEntries(updates)); - }, + // async onQueryStarted(_, { dispatch, queryFulfilled }) { + // // Populate the getImageDTO cache with these images, similar to listImages + // const res = await queryFulfilled; + // const imageDTOs = res.data.items; + // const updates: Param0 = []; + // for (const imageDTO of imageDTOs) { + // updates.push({ + // endpointName: 'getImageDTO', + // arg: imageDTO.image_name, + // value: imageDTO, + // }); + // } + // dispatch(imagesApi.util.upsertQueryEntries(updates)); + // }, }), }), }); diff --git a/invokeai/frontend/web/src/services/api/index.ts b/invokeai/frontend/web/src/services/api/index.ts index b2b65b4d99..a5d3ddbec8 100644 --- a/invokeai/frontend/web/src/services/api/index.ts +++ b/invokeai/frontend/web/src/services/api/index.ts @@ -10,6 +10,7 @@ import { buildCreateApi, coreModule, fetchBaseQuery, reactHooksModule } from '@r import { $authToken } from 'app/store/nanostores/authToken'; import { $baseUrl } from 'app/store/nanostores/baseUrl'; import { $projectId } from 'app/store/nanostores/projectId'; +import queryString from 'query-string'; const tagTypes = [ 'AppVersion', @@ -133,5 +134,10 @@ function getCircularReplacer() { }; } -export const buildV1Url = (path: string): string => `api/v1/${path}`; +export const buildV1Url = (path: string, query?: Parameters[0]): string => { + if (!query) { + return `api/v1/${path}`; + } + return `api/v1/${path}?${queryString.stringify(query)}`; +}; export const buildV2Url = (path: string): string => `api/v2/${path}`;