mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): consolidated gallery (wip)
This commit is contained in:
@@ -122,7 +122,7 @@ export const GalleryPanel = memo(() => {
|
||||
</Collapse>
|
||||
<Divider pt={2} />
|
||||
<Flex w="full" h="full" pt={2}>
|
||||
{galleryView === 'images' ? <NewGallery /> : galleryView === 'videos' ? <VideoGallery /> : <NewGallery />}
|
||||
<NewGallery />
|
||||
</Flex>
|
||||
</Flex>
|
||||
);
|
||||
|
||||
@@ -4,7 +4,7 @@ import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { getFocusedRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBasedImageFetching';
|
||||
import type { selectGetImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import type { selectGetImageNamesQueryArgs, selectGetVideoIdsQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
selectGalleryImageMinimumWidth,
|
||||
selectGalleryView,
|
||||
@@ -28,23 +28,30 @@ import type {
|
||||
} from 'react-virtuoso';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import { imagesApi, useImageDTO, useStarImagesMutation, useUnstarImagesMutation } from 'services/api/endpoints/images';
|
||||
import { videosApi } from 'services/api/endpoints/videos';
|
||||
import { useStarVideosMutation, useUnstarVideosMutation, useVideoDTO, videosApi } from 'services/api/endpoints/videos';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
import { GalleryImage, GalleryImagePlaceholder } from './ImageGrid/GalleryImage';
|
||||
import { GallerySelectionCountTag } from './ImageGrid/GallerySelectionCountTag';
|
||||
import { useGalleryImageNames } from './use-gallery-image-names';
|
||||
import { useGalleryVideoIds } from './use-gallery-video-ids';
|
||||
import { GalleryVideo } from './ImageGrid/GalleryVideo';
|
||||
import { useGalleryImageNames, useGalleryVideoIds } from './use-gallery-image-names';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
type ListImageNamesQueryArgs = ReturnType<typeof selectGetImageNamesQueryArgs>;
|
||||
type ListVideoIdsQueryArgs = ReturnType<typeof selectGetVideoIdsQueryArgs>;
|
||||
|
||||
type GridContext = {
|
||||
queryArgs: ListImageNamesQueryArgs;
|
||||
imageNames: string[];
|
||||
};
|
||||
type GridContext =
|
||||
| {
|
||||
queryArgs: ListImageNamesQueryArgs;
|
||||
galleryView: 'images' | 'assets';
|
||||
itemIds: string[];
|
||||
}
|
||||
| {
|
||||
queryArgs: ListVideoIdsQueryArgs;
|
||||
galleryView: 'videos';
|
||||
itemIds: string[];
|
||||
};
|
||||
|
||||
const ImageAtPosition = memo(({ imageName }: { index: number; imageName: string }) => {
|
||||
/*
|
||||
@@ -96,8 +103,8 @@ const VideoAtPosition = memo(({ itemId }: { index: number; itemId: string }) =>
|
||||
});
|
||||
VideoAtPosition.displayName = 'VideoAtPosition';
|
||||
|
||||
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
|
||||
return `${JSON.stringify(queryArgs)}-${imageName ?? index}`;
|
||||
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, id, { queryArgs }) => {
|
||||
return `${JSON.stringify(queryArgs)}-${id ?? index}`;
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -106,7 +113,7 @@ const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageNam
|
||||
* TODO(psyche): We only need to do this when the gallery width changes, or when the galleryImageMinimumWidth value
|
||||
* changes. Cache this calculation.
|
||||
*/
|
||||
const getImagesPerRow = (rootEl: HTMLDivElement): number => {
|
||||
const getItemsPerRow = (rootEl: HTMLDivElement): number => {
|
||||
// Start from root and find virtuoso grid elements
|
||||
const gridElement = rootEl.querySelector('.virtuoso-grid-list');
|
||||
|
||||
@@ -140,20 +147,20 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => {
|
||||
*
|
||||
* Instead, we use a more robust approach that iteratively calculates how many images fit in the row.
|
||||
*/
|
||||
let imagesPerRow = 0;
|
||||
let itemsPerRow = 0;
|
||||
let spaceUsed = 0;
|
||||
|
||||
// Floating point precision can cause imagesPerRow to be 1 too small. Adding 1px to the container size fixes
|
||||
// this, without the possibility of accidentally adding an extra column.
|
||||
while (spaceUsed + itemRect.width <= containerRect.width + 1) {
|
||||
imagesPerRow++; // Increment the number of images
|
||||
itemsPerRow++; // Increment the number of images
|
||||
spaceUsed += itemRect.width; // Add image size to the used space
|
||||
if (spaceUsed + gap <= containerRect.width) {
|
||||
spaceUsed += gap; // Add gap size to the used space after each image except after the last image
|
||||
}
|
||||
}
|
||||
|
||||
return Math.max(1, imagesPerRow);
|
||||
return Math.max(1, itemsPerRow);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -180,9 +187,7 @@ const scrollIntoView = (
|
||||
return;
|
||||
}
|
||||
|
||||
const targetItem = rootEl.querySelector(
|
||||
`.virtuoso-grid-item:has([data-item-id="${targetItemId}"])`
|
||||
) as HTMLElement;
|
||||
const targetItem = rootEl.querySelector(`.virtuoso-grid-item:has([data-item-id="${targetItemId}"])`) as HTMLElement;
|
||||
|
||||
if (!targetItem) {
|
||||
if (targetIndex > range.endIndex) {
|
||||
@@ -268,11 +273,11 @@ const scrollIntoView = (
|
||||
* If the image name is not found, return 0.
|
||||
* If no image name is provided, return 0.
|
||||
*/
|
||||
const getImageIndex = (imageName: string | undefined | null, imageNames: string[]) => {
|
||||
if (!imageName || imageNames.length === 0) {
|
||||
const getItemIndex = (targetItemId: string | undefined | null, itemIds: string[]) => {
|
||||
if (!targetItemId || itemIds.length === 0) {
|
||||
return 0;
|
||||
}
|
||||
const index = imageNames.findIndex((n) => n === imageName);
|
||||
const index = itemIds.findIndex((n) => n === targetItemId);
|
||||
return index >= 0 ? index : 0;
|
||||
};
|
||||
|
||||
@@ -280,7 +285,7 @@ const getImageIndex = (imageName: string | undefined | null, imageNames: string[
|
||||
* Handles keyboard navigation for the gallery.
|
||||
*/
|
||||
const useKeyboardNavigation = (
|
||||
imageNames: string[],
|
||||
itemIds: string[],
|
||||
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
|
||||
rootRef: React.RefObject<HTMLDivElement>
|
||||
) => {
|
||||
@@ -308,13 +313,13 @@ const useKeyboardNavigation = (
|
||||
return;
|
||||
}
|
||||
|
||||
if (imageNames.length === 0) {
|
||||
if (itemIds.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const imagesPerRow = getImagesPerRow(rootEl);
|
||||
const itemsPerRow = getItemsPerRow(rootEl);
|
||||
|
||||
if (imagesPerRow === 0) {
|
||||
if (itemsPerRow === 0) {
|
||||
// This can happen if the grid is not yet rendered or has no items
|
||||
return;
|
||||
}
|
||||
@@ -322,13 +327,14 @@ const useKeyboardNavigation = (
|
||||
event.preventDefault();
|
||||
|
||||
const state = getState();
|
||||
const imageName = event.altKey
|
||||
? // When the user holds alt, we are changing the image to compare - if no image to compare is currently selected,
|
||||
// we start from the last selected image
|
||||
(selectImageToCompare(state) ?? selectLastSelectedImage(state))
|
||||
: selectLastSelectedImage(state);
|
||||
const imageName =
|
||||
event.altKey && selectGalleryView(state) !== 'videos'
|
||||
? // When the user holds alt, we are changing the image to compare - if no image to compare is currently selected,
|
||||
// we start from the last selected image
|
||||
(selectImageToCompare(state) ?? selectLastSelectedImage(state))
|
||||
: selectLastSelectedImage(state);
|
||||
|
||||
const currentIndex = getImageIndex(imageName, imageNames);
|
||||
const currentIndex = getItemIndex(imageName, itemIds);
|
||||
|
||||
let newIndex = currentIndex;
|
||||
|
||||
@@ -342,7 +348,7 @@ const useKeyboardNavigation = (
|
||||
}
|
||||
break;
|
||||
case 'ArrowRight':
|
||||
if (currentIndex < imageNames.length - 1) {
|
||||
if (currentIndex < itemIds.length - 1) {
|
||||
newIndex = currentIndex + 1;
|
||||
// } else {
|
||||
// // Wrap to first image
|
||||
@@ -351,26 +357,26 @@ const useKeyboardNavigation = (
|
||||
break;
|
||||
case 'ArrowUp':
|
||||
// If on first row, stay on current image
|
||||
if (currentIndex < imagesPerRow) {
|
||||
if (currentIndex < itemsPerRow) {
|
||||
newIndex = currentIndex;
|
||||
} else {
|
||||
newIndex = Math.max(0, currentIndex - imagesPerRow);
|
||||
newIndex = Math.max(0, currentIndex - itemsPerRow);
|
||||
}
|
||||
break;
|
||||
case 'ArrowDown':
|
||||
// If no images below, stay on current image
|
||||
if (currentIndex >= imageNames.length - imagesPerRow) {
|
||||
if (currentIndex >= itemIds.length - itemsPerRow) {
|
||||
newIndex = currentIndex;
|
||||
} else {
|
||||
newIndex = Math.min(imageNames.length - 1, currentIndex + imagesPerRow);
|
||||
newIndex = Math.min(itemIds.length - 1, currentIndex + itemsPerRow);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < imageNames.length) {
|
||||
const newImageName = imageNames[newIndex];
|
||||
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < itemIds.length) {
|
||||
const newImageName = itemIds[newIndex];
|
||||
if (newImageName) {
|
||||
if (event.altKey) {
|
||||
if (selectGalleryView(state) !== 'videos' && event.altKey) {
|
||||
dispatch(imageToCompareChanged(newImageName));
|
||||
} else {
|
||||
dispatch(selectionChanged([newImageName]));
|
||||
@@ -378,7 +384,7 @@ const useKeyboardNavigation = (
|
||||
}
|
||||
}
|
||||
},
|
||||
[rootRef, virtuosoRef, imageNames, getState, dispatch]
|
||||
[rootRef, virtuosoRef, itemIds, getState, dispatch]
|
||||
);
|
||||
|
||||
useRegisteredHotkeys({
|
||||
@@ -451,8 +457,8 @@ const useKeyboardNavigation = (
|
||||
* This is useful for keyboard navigation and ensuring the user can see their selection.
|
||||
* It only tracks the last selected image, not the image to compare.
|
||||
*/
|
||||
const useKeepSelectedImageInView = (
|
||||
imageNames: string[],
|
||||
const useKeepSelectedItemInView = (
|
||||
itemIds: string[],
|
||||
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
|
||||
rootRef: React.RefObject<HTMLDivElement>,
|
||||
rangeRef: MutableRefObject<ListRange>
|
||||
@@ -460,19 +466,19 @@ const useKeepSelectedImageInView = (
|
||||
const selection = useAppSelector(selectSelection);
|
||||
|
||||
useEffect(() => {
|
||||
const targetImageName = selection.at(-1);
|
||||
const targetItemId = selection.at(-1);
|
||||
const virtuosoGridHandle = virtuosoRef.current;
|
||||
const rootEl = rootRef.current;
|
||||
const range = rangeRef.current;
|
||||
|
||||
if (!virtuosoGridHandle || !rootEl || !targetImageName || !imageNames || imageNames.length === 0) {
|
||||
if (!virtuosoGridHandle || !rootEl || !targetItemId || !itemIds || itemIds.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
scrollIntoView(targetImageName, imageNames, rootEl, virtuosoGridHandle, range);
|
||||
scrollIntoView(targetItemId, itemIds, rootEl, virtuosoGridHandle, range);
|
||||
}, 0);
|
||||
}, [imageNames, rangeRef, rootRef, virtuosoRef, selection]);
|
||||
}, [itemIds, rangeRef, rootRef, virtuosoRef, selection]);
|
||||
};
|
||||
|
||||
/**
|
||||
@@ -523,22 +529,32 @@ const useScrollableGallery = (rootRef: RefObject<HTMLDivElement>) => {
|
||||
const useStarImageHotkey = () => {
|
||||
const lastSelectedImage = useAppSelector(selectLastSelectedImage);
|
||||
const selectionCount = useAppSelector(selectSelectionCount);
|
||||
const galleryView = useAppSelector(selectGalleryView);
|
||||
const isGalleryFocused = useIsRegionFocused('gallery');
|
||||
const imageDTO = useImageDTO(lastSelectedImage);
|
||||
const imageDTO = useImageDTO(galleryView !== 'videos' ? lastSelectedImage : null);
|
||||
const videoDTO = useVideoDTO(galleryView === 'videos' ? lastSelectedImage : null);
|
||||
const [starImages] = useStarImagesMutation();
|
||||
const [unstarImages] = useUnstarImagesMutation();
|
||||
|
||||
const [starVideos] = useStarVideosMutation();
|
||||
const [unstarVideos] = useUnstarVideosMutation();
|
||||
|
||||
const handleStarHotkey = useCallback(() => {
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
if (!isGalleryFocused) {
|
||||
return;
|
||||
}
|
||||
if (imageDTO.starred) {
|
||||
unstarImages({ image_names: [imageDTO.image_name] });
|
||||
} else {
|
||||
starImages({ image_names: [imageDTO.image_name] });
|
||||
if (galleryView === 'videos' && videoDTO) {
|
||||
if (videoDTO.starred) {
|
||||
unstarVideos({ video_ids: [videoDTO.video_id] });
|
||||
} else {
|
||||
starVideos({ video_ids: [videoDTO.video_id] });
|
||||
}
|
||||
} else if (galleryView !== 'videos' && imageDTO) {
|
||||
if (imageDTO.starred) {
|
||||
unstarImages({ image_names: [imageDTO.image_name] });
|
||||
} else {
|
||||
starImages({ image_names: [imageDTO.image_name] });
|
||||
}
|
||||
}
|
||||
}, [imageDTO, isGalleryFocused, starImages, unstarImages]);
|
||||
|
||||
@@ -546,7 +562,12 @@ const useStarImageHotkey = () => {
|
||||
id: 'starImage',
|
||||
category: 'gallery',
|
||||
callback: handleStarHotkey,
|
||||
options: { enabled: !!imageDTO && selectionCount === 1 && isGalleryFocused },
|
||||
options: {
|
||||
enabled:
|
||||
((galleryView === 'videos' && !!videoDTO) || (galleryView !== 'videos' && !!imageDTO)) &&
|
||||
selectionCount === 1 &&
|
||||
isGalleryFocused,
|
||||
},
|
||||
dependencies: [imageDTO, selectionCount, isGalleryFocused, handleStarHotkey],
|
||||
});
|
||||
};
|
||||
@@ -558,18 +579,22 @@ export const NewGallery = memo(() => {
|
||||
const galleryView = useAppSelector(selectGalleryView);
|
||||
|
||||
// Get the ordered list of image names - this is our primary data source for virtualization
|
||||
const { queryArgs, imageNames, isLoading } = useGalleryImageNames();
|
||||
const { queryArgs: videoQueryArgs, videoIds, isLoading: isLoadingVideos } = useGalleryVideoIds();
|
||||
const galleryImageNamesQuery = useGalleryImageNames();
|
||||
const galleryVideoIdsQuery = useGalleryVideoIds();
|
||||
|
||||
// Use range-based fetching for bulk loading image DTOs into cache based on the visible range
|
||||
const { onRangeChanged } = useRangeBasedImageFetching({
|
||||
imageNames,
|
||||
enabled: !isLoading,
|
||||
imageNames: galleryImageNamesQuery.imageNames,
|
||||
enabled: !galleryImageNamesQuery.isLoading,
|
||||
});
|
||||
|
||||
const itemIds = galleryView === 'videos' ? galleryVideoIdsQuery.video_ids : galleryImageNamesQuery.imageNames;
|
||||
const queryArgs = galleryView === 'videos' ? galleryVideoIdsQuery.queryArgs : galleryImageNamesQuery.queryArgs;
|
||||
const isLoading = galleryView === 'videos' ? galleryVideoIdsQuery.isLoading : galleryImageNamesQuery.isLoading;
|
||||
useStarImageHotkey();
|
||||
useKeepSelectedImageInView(imageNames, virtuosoRef, rootRef, rangeRef);
|
||||
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
|
||||
|
||||
useKeepSelectedItemInView(itemIds, virtuosoRef, rootRef, rangeRef);
|
||||
useKeyboardNavigation(itemIds, virtuosoRef, rootRef);
|
||||
const scrollerRef = useScrollableGallery(rootRef);
|
||||
|
||||
/*
|
||||
@@ -584,7 +609,7 @@ export const NewGallery = memo(() => {
|
||||
[onRangeChanged]
|
||||
);
|
||||
|
||||
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs, videoIds, videoQueryArgs }), [imageNames, queryArgs, videoIds, videoQueryArgs]);
|
||||
const context = useMemo<GridContext>(() => ({ itemIds, galleryView, queryArgs }), [itemIds, queryArgs, galleryView]);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
@@ -595,7 +620,7 @@ export const NewGallery = memo(() => {
|
||||
);
|
||||
}
|
||||
|
||||
if (imageNames.length === 0) {
|
||||
if (itemIds.length === 0) {
|
||||
return (
|
||||
<Flex w="full" h="full" alignItems="center" justifyContent="center">
|
||||
<Text color="base.300">No images found</Text>
|
||||
@@ -609,7 +634,7 @@ export const NewGallery = memo(() => {
|
||||
<VirtuosoGrid<string, GridContext>
|
||||
ref={virtuosoRef}
|
||||
context={context}
|
||||
data={galleryView === 'images' ? imageNames : videoIds}
|
||||
data={itemIds}
|
||||
increaseViewportBy={4096}
|
||||
itemContent={itemContent}
|
||||
computeItemKey={computeItemKey}
|
||||
@@ -652,8 +677,12 @@ const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context
|
||||
});
|
||||
ListComponent.displayName = 'ListComponent';
|
||||
|
||||
const itemContent: GridItemContent<string, GridContext> = (index, imageName) => {
|
||||
return <ImageAtPosition index={index} imageName={imageName} />;
|
||||
const itemContent: GridItemContent<string, GridContext> = (index, itemId, { galleryView }) => {
|
||||
if (galleryView === 'videos') {
|
||||
return <VideoAtPosition index={index} itemId={itemId} />;
|
||||
} else {
|
||||
return <ImageAtPosition index={index} imageName={itemId} />;
|
||||
}
|
||||
};
|
||||
|
||||
const ItemComponent: GridComponents<GridContext>['Item'] = forwardRef(({ context: _, ...rest }, ref) => (
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
import { EMPTY_ARRAY } from 'app/store/constants';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectGetImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
selectGalleryView,
|
||||
selectGetImageNamesQueryArgs,
|
||||
selectGetVideoIdsQueryArgs,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { useGetImageNamesQuery } from 'services/api/endpoints/images';
|
||||
import { useGetVideoIdsQuery } from 'services/api/endpoints/videos';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const getImageNamesQueryOptions = {
|
||||
@@ -13,9 +20,33 @@ const getImageNamesQueryOptions = {
|
||||
}),
|
||||
} satisfies Parameters<typeof useGetImageNamesQuery>[1];
|
||||
|
||||
const getVideoIdsQueryOptions = {
|
||||
refetchOnReconnect: true,
|
||||
selectFromResult: ({ currentData, isLoading, isFetching }) => ({
|
||||
video_ids: currentData?.video_ids ?? EMPTY_ARRAY,
|
||||
isLoading,
|
||||
isFetching,
|
||||
}),
|
||||
} satisfies Parameters<typeof useGetVideoIdsQuery>[1];
|
||||
|
||||
export const useGalleryImageNames = () => {
|
||||
const galleryView = useAppSelector(selectGalleryView);
|
||||
const _queryArgs = useAppSelector(selectGetImageNamesQueryArgs);
|
||||
const [queryArgs] = useDebounce(_queryArgs, 300);
|
||||
const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(queryArgs, getImageNamesQueryOptions);
|
||||
const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(
|
||||
galleryView !== 'videos' ? queryArgs : skipToken,
|
||||
getImageNamesQueryOptions
|
||||
);
|
||||
return { imageNames, isLoading, isFetching, queryArgs };
|
||||
};
|
||||
|
||||
export const useGalleryVideoIds = () => {
|
||||
const galleryView = useAppSelector(selectGalleryView);
|
||||
const _queryArgs = useAppSelector(selectGetVideoIdsQueryArgs);
|
||||
const [queryArgs] = useDebounce(_queryArgs, 300);
|
||||
const { video_ids, isLoading, isFetching } = useGetVideoIdsQuery(
|
||||
galleryView === 'videos' ? queryArgs : skipToken,
|
||||
getVideoIdsQueryOptions
|
||||
);
|
||||
return { video_ids, isLoading, isFetching, queryArgs };
|
||||
};
|
||||
|
||||
@@ -9,6 +9,7 @@ import stableHash from 'stable-hash';
|
||||
import type { Param0 } from 'tsafe';
|
||||
|
||||
import { api, buildV1Url, LIST_TAG } from '..';
|
||||
import { skipToken } from '@reduxjs/toolkit/query';
|
||||
|
||||
/**
|
||||
* Builds an endpoint URL for the videos router
|
||||
@@ -224,3 +225,8 @@ export const getVideoDTOSafe = async (
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
export const useVideoDTO = (id: string | null | undefined) => {
|
||||
const { currentData: videoDTO } = useGetVideoDTOQuery(id ?? skipToken);
|
||||
return videoDTO ?? null;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user