diff --git a/invokeai/app/services/images/images_base.py b/invokeai/app/services/images/images_base.py index 4886d31cca..037d463e33 100644 --- a/invokeai/app/services/images/images_base.py +++ b/invokeai/app/services/images/images_base.py @@ -126,7 +126,7 @@ class ImageServiceABC(ABC): board_id: Optional[str] = None, search_term: Optional[str] = None, ) -> OffsetPaginatedResults[ImageDTO]: - """Gets a paginated list of image DTOs.""" + """Gets a paginated list of image DTOs with starred images first when starred_first=True.""" pass @abstractmethod diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index d1a029f1b6..ec757494f5 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -170,8 +170,6 @@ export const createStore = (uniqueStoreKey?: string, persist = true) => reducer: rememberedRootReducer, middleware: (getDefaultMiddleware) => getDefaultMiddleware({ - // serializableCheck: false, - // immutableCheck: false, serializableCheck: import.meta.env.MODE === 'development', immutableCheck: import.meta.env.MODE === 'development', }) diff --git a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx index a663b2514d..ae1d104763 100644 --- a/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx +++ b/invokeai/frontend/web/src/features/gallery/components/NewGallery.tsx @@ -8,6 +8,7 @@ import { } from 'features/gallery/store/gallerySelectors'; import { selectionChanged } from 'features/gallery/store/gallerySlice'; import { useOverlayScrollbars } from 'overlayscrollbars-react'; +import type { MutableRefObject } from 'react'; import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import type { GridComponents, @@ -18,120 +19,64 @@ import type { VirtuosoGridHandle, } from 'react-virtuoso'; import { VirtuosoGrid } from 'react-virtuoso'; -import { - useGetImageCollectionCountsQuery, - useGetImageCollectionQuery, - useGetImageNamesQuery, - useLazyGetImageCollectionQuery, -} from 'services/api/endpoints/images'; -import type { ImageCategory, ImageDTO, SQLiteDirection } from 'services/api/types'; -import { objectEntries } from 'tsafe'; +import { useGetImageNamesQuery, useListImagesQuery } from 'services/api/endpoints/images'; +import type { ImageDTO, ListImagesArgs } from 'services/api/types'; import { useDebounce } from 'use-debounce'; import { GalleryImage } from './ImageGrid/GalleryImage'; const log = logger('gallery'); -// Type for image collection query arguments -type ImageCollectionQueryArgs = { - board_id?: string; - categories?: ImageCategory[]; - search_term?: string; - order_dir?: SQLiteDirection; - is_intermediate: boolean; -}; - // Constants -const RANGE_SIZE = 50; +const PAGE_SIZE = 100; const VIEWPORT_BUFFER = 2048; -const SCROLL_SEEK_VELOCITY_THRESHOLD = 2048; +const SCROLL_SEEK_VELOCITY_THRESHOLD = 4096; const DEBOUNCE_DELAY = 500; -const GRID_GAP = 2; +const GRID_GAP = 1; const SPINNER_OPACITY = 0.3; type GridContext = { - queryArgs: ImageCollectionQueryArgs; + queryArgs: ListImagesArgs; imageNames: string[]; - starredCount: number; }; -type PositionInfo = { - collection: 'starred' | 'unstarred'; - offset: number; - itemIndex: number; +export const useDebouncedImageCollectionQueryArgs = () => { + const _galleryQueryArgs = useAppSelector(selectImageCollectionQueryArgs); + const [queryArgs] = useDebounce(_galleryQueryArgs, DEBOUNCE_DELAY); + return queryArgs; }; -// Helper to calculate which collection and range an index belongs to -const getPositionInfo = (index: number, starredCount: number): PositionInfo => { - if (index < starredCount) { - // Starred collection - const offset = Math.floor(index / RANGE_SIZE) * RANGE_SIZE; - return { - collection: 'starred', - offset, - itemIndex: index - offset, - }; - } else { - // Unstarred collection - const unstarredIndex = index - starredCount; - const offset = Math.floor(unstarredIndex / RANGE_SIZE) * RANGE_SIZE; - return { - collection: 'unstarred', - offset, - itemIndex: unstarredIndex - offset, - }; - } -}; - -// Hook to get image DTO from batched collection data -const useImageFromBatch = ( - imageName: string, - index: number, - starredCount: number, - queryArgs: ImageCollectionQueryArgs -): ImageDTO | null => { +// Hook to get an image DTO from cache or trigger loading +const useImageDTOFromListQuery = (index: number, imageName: string, queryArgs: ListImagesArgs): ImageDTO | null => { const { arg, options } = useMemo(() => { - const positionInfo = getPositionInfo(index, starredCount); + const pageOffset = Math.floor(index / PAGE_SIZE) * PAGE_SIZE; + return { + arg: { + ...queryArgs, + offset: pageOffset, + limit: PAGE_SIZE, + } satisfies Parameters[0], + options: { + selectFromResult: ({ data }) => { + const imageDTO = data?.items?.[index - pageOffset] || null; + if (imageDTO && imageDTO.image_name !== imageName) { + log.warn(`Image at index ${index} does not match expected image name ${imageName}`); + } + return { imageDTO }; + }, + } satisfies Parameters[1], + }; + }, [index, queryArgs, imageName]); - const arg = { - collection: positionInfo.collection, - offset: positionInfo.offset, - limit: RANGE_SIZE, - ...queryArgs, - } satisfies Parameters[0]; - - const options = { - selectFromResult: ({ data }) => { - const imageDTO = data?.items?.[positionInfo.itemIndex] || null; - if (imageDTO && imageDTO.image_name !== imageName) { - log.warn(`Image name mismatch at index ${index}: expected ${imageName}, got ${imageDTO.image_name}`); - } - return { imageDTO }; - }, - } satisfies Parameters[1]; - - return { arg, options }; - }, [imageName, index, queryArgs, starredCount]); - - const { imageDTO } = useGetImageCollectionQuery(arg, options); + const { imageDTO } = useListImagesQuery(arg, options); return imageDTO; }; -// Individual image component that gets its data from batched requests +// Individual image component that gets its data from RTK Query cache const ImageAtPosition = memo( - ({ - imageName, - index, - starredCount, - queryArgs, - }: { - imageName: string; - index: number; - starredCount: number; - queryArgs: ImageCollectionQueryArgs; - }) => { - const imageDTO = useImageFromBatch(imageName, index, starredCount, queryArgs); + ({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImagesArgs }) => { + const imageDTO = useImageDTOFromListQuery(index, imageName, queryArgs); if (!imageDTO) { return ; @@ -142,58 +87,9 @@ const ImageAtPosition = memo( ); ImageAtPosition.displayName = 'ImageAtPosition'; -export const useDebouncedImageCollectionQueryArgs = () => { - const _queryArgs = useAppSelector(selectImageCollectionQueryArgs); - const [queryArgs] = useDebounce(_queryArgs, DEBOUNCE_DELAY); - return queryArgs; -}; - -// Memoized item content function that uses image names as data but batches requests -const itemContent: GridItemContent = (index, imageName, { queryArgs, starredCount }) => { - if (!imageName) { - return ; - } - return ; -}; - // Memoized compute key function using image names const computeItemKey: GridComputeItemKey = (index, imageName, { queryArgs }) => { - return `${JSON.stringify(queryArgs)}-${imageName || index}`; -}; - -// Hook to prefetch ranges based on visible area -const usePrefetchRanges = (starredCount: number, queryArgs: ImageCollectionQueryArgs) => { - const [triggerGetImageCollection] = useLazyGetImageCollectionQuery(); - - const prefetchRange = useCallback( - (startIndex: number, endIndex: number) => { - const ranges = { - starred: new Set(), - unstarred: new Set(), - }; - - // Collect all unique ranges needed for the visible area - for (let i = startIndex; i <= endIndex; i++) { - const positionInfo = getPositionInfo(i, starredCount); - ranges[positionInfo.collection].add(positionInfo.offset); - } - - // Trigger queries for each unique range - for (const [collection, offsets] of objectEntries(ranges)) { - for (const offset of offsets) { - triggerGetImageCollection({ - collection, - offset, - limit: RANGE_SIZE, - ...queryArgs, - }); - } - } - }, - [starredCount, queryArgs, triggerGetImageCollection] - ); - - return prefetchRange; + return `${JSON.stringify(queryArgs)}-${imageName}`; }; // Physical DOM-based grid calculation using refs (based on working old implementation) @@ -241,40 +137,72 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => { }; // Check if an item at a given index is visible in the viewport -const isItemVisible = (index: number, rootEl: HTMLDivElement): null | 'start' | 'center' | 'end' => { +const scrollIntoView = ( + index: number, + rootEl: HTMLDivElement, + virtuosoGridHandle: VirtuosoGridHandle, + range: ListRange +) => { + if (range.endIndex === 0) { + return; + } + // First get the virtuoso grid list root element const gridList = rootEl.querySelector('.virtuoso-grid-list') as HTMLElement; if (!gridList) { - return null; + // No grid - cannot scroll! + return; } // Then find the specific item within the grid list const targetItem = gridList.querySelector(`.virtuoso-grid-item[data-index="${index}"]`) as HTMLElement; if (!targetItem) { - return null; + if (index > range.endIndex) { + virtuosoGridHandle.scrollToIndex({ + index, + behavior: 'auto', + align: 'start', + }); + } else if (index < range.startIndex) { + virtuosoGridHandle.scrollToIndex({ + index, + behavior: 'auto', + align: 'end', + }); + } else { + log.warn(`Unable to find item index ${index} but it is in range ${range.startIndex}-${range.endIndex}`); + } + return; } const itemRect = targetItem.getBoundingClientRect(); const rootRect = rootEl.getBoundingClientRect(); if (itemRect.top < rootRect.top) { - return 'start'; + virtuosoGridHandle.scrollToIndex({ + index, + behavior: 'auto', + align: 'start', + }); + } else if (itemRect.bottom > rootRect.bottom) { + virtuosoGridHandle.scrollToIndex({ + index, + behavior: 'auto', + align: 'end', + }); } - if (itemRect.bottom > rootRect.bottom) { - return 'end'; - } - - return 'center'; + return; }; // Hook for keyboard navigation using physical DOM measurements const useKeyboardNavigation = ( imageNames: string[], virtuosoRef: React.RefObject, - rootRef: React.RefObject + rootRef: React.RefObject, + rangeRef: MutableRefObject ) => { const dispatch = useAppDispatch(); const lastSelectedImage = useAppSelector(selectLastSelectedImage); @@ -291,7 +219,9 @@ const useKeyboardNavigation = ( const handleKeyDown = useCallback( (event: KeyboardEvent) => { const rootEl = rootRef.current; - if (!rootEl) { + const virtuosoGridHandle = virtuosoRef.current; + const range = rangeRef.current; + if (!rootEl || !virtuosoGridHandle) { return; } if (imageNames.length === 0) { @@ -358,21 +288,11 @@ const useKeyboardNavigation = ( const newImageName = imageNames[newIndex]; if (newImageName) { dispatch(selectionChanged([newImageName])); - - // Only scroll if the selected item is not visible - const vis = isItemVisible(newIndex, rootEl); - if (!vis || vis === 'center') { - return; - } - virtuosoRef.current?.scrollToIndex({ - index: newIndex, - behavior: 'smooth', - align: vis, - }); + scrollIntoView(newIndex, rootEl, virtuosoGridHandle, range); } } }, - [rootRef, imageNames, currentIndex, dispatch, virtuosoRef] + [rootRef, virtuosoRef, rangeRef, imageNames, currentIndex, dispatch] ); useEffect(() => { @@ -387,16 +307,11 @@ const useKeyboardNavigation = ( export const NewGallery = memo(() => { const queryArgs = useDebouncedImageCollectionQueryArgs(); const virtuosoRef = useRef(null); + const rangeRef = useRef({ startIndex: 0, endIndex: 0 }); - // Get the ordered list of image names - this is our primary data source + // Get the ordered list of image names - this is our primary data source for virtualization const { data: imageNames = [], isLoading } = useGetImageNamesQuery(queryArgs); - // Get starred count for position calculations - const { data: counts } = useGetImageCollectionCountsQuery(queryArgs); - const starredCount = counts?.starred_count ?? 0; - - const prefetchRange = usePrefetchRanges(starredCount, queryArgs); - // Reset scroll position when query parameters change useEffect(() => { if (virtuosoRef.current && imageNames.length > 0) { @@ -407,7 +322,7 @@ export const NewGallery = memo(() => { const rootRef = useRef(null); // Enable keyboard navigation - useKeyboardNavigation(imageNames, virtuosoRef, rootRef); + useKeyboardNavigation(imageNames, virtuosoRef, rootRef, rangeRef); const [scroller, setScroller] = useState(null); const [initialize, osInstance] = useOverlayScrollbars({ @@ -439,24 +354,25 @@ export const NewGallery = memo(() => { }; }, [scroller, initialize, osInstance]); - // Handle range changes to prefetch data for visible + buffer areas - const handleRangeChanged = useCallback( - (range: ListRange) => { - prefetchRange(range.startIndex, range.endIndex); - }, - [prefetchRange] - ); + // Handle range changes - RTK Query will automatically cache and manage loading + const handleRangeChanged = useCallback((range: ListRange) => { + rangeRef.current = range; + }, []); const context = useMemo( () => ({ imageNames, queryArgs, - starredCount, }) satisfies GridContext, - [imageNames, queryArgs, starredCount] + [imageNames, queryArgs] ); + // Item content function + const itemContent: GridItemContent = useCallback((index, imageName, ctx) => { + return ; + }, []); + if (isLoading) { return ( diff --git a/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts b/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts index 3e132a2d72..ae9affafdc 100644 --- a/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts +++ b/invokeai/frontend/web/src/features/gallery/store/gallerySelectors.ts @@ -45,6 +45,7 @@ export const selectImageCollectionQueryArgs = createMemoizedSelector(selectGalle search_term: gallery.searchTerm || undefined, order_dir: gallery.orderDir as SQLiteDirection, is_intermediate: false, + starred_first: true, })); export const selectAutoAssignBoardOnClick = createSelector( selectGallerySlice, diff --git a/invokeai/frontend/web/src/services/api/endpoints/images.ts b/invokeai/frontend/web/src/services/api/endpoints/images.ts index edf8786e32..aa72173c7f 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/images.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/images.ts @@ -50,10 +50,10 @@ export const imagesApi = api.injectEndpoints({ url: getListImagesUrl(queryArgs), method: 'GET', }), - providesTags: (result, error, { board_id, categories }) => { + providesTags: (result, error, queryArgs) => { return [ // Make the tags the same as the cache key - { type: 'ImageList', id: getListImagesUrl({ board_id, categories }) }, + { type: 'ImageList', id: JSON.stringify(queryArgs) }, 'FetchOnReconnect', ]; }, @@ -493,6 +493,45 @@ export const imagesApi = api.injectEndpoints({ }), providesTags: ['ImageNameList', 'FetchOnReconnect'], }), + /** + * Get paginated images with starred first (unified list) + */ + getUnifiedImageList: build.query< + ListImagesResponse, + { + offset?: number; + limit?: number; + image_origin?: 'internal' | 'external' | null; + categories?: ImageCategory[] | null; + is_intermediate?: boolean | null; + board_id?: string | null; + search_term?: string | null; + order_dir?: SQLiteDirection; + } + >({ + query: (queryArgs) => ({ + url: getListImagesUrl({ ...queryArgs, starred_first: true }), + method: 'GET', + }), + providesTags: (result, error, { board_id, categories }) => [ + { type: 'ImageList', id: getListImagesUrl({ board_id, categories }) }, + 'FetchOnReconnect', + ], + async onQueryStarted(_, { dispatch, queryFulfilled }) { + // Populate the getImageDTO cache with these images + const res = await queryFulfilled; + const imageDTOs = res.data.items; + const updates: Param0 = []; + for (const imageDTO of imageDTOs) { + updates.push({ + endpointName: 'getImageDTO', + arg: imageDTO.image_name, + value: imageDTO, + }); + } + dispatch(imagesApi.util.upsertQueryEntries(updates)); + }, + }), }), }); @@ -518,6 +557,7 @@ export const { useGetImageCollectionQuery, useLazyGetImageCollectionQuery, useGetImageNamesQuery, + useGetUnifiedImageListQuery, } = imagesApi; /**