integrating video into gallery - thinking maybe a new category of image would make more senes

This commit is contained in:
Mary Hipp
2025-08-14 16:17:00 -04:00
committed by psychedelicious
parent d0e12c31f7
commit 66d8b86149
14 changed files with 1138 additions and 16 deletions

View File

@@ -411,7 +411,9 @@
"openViewer": "Open Viewer",
"closeViewer": "Close Viewer",
"move": "Move",
"useForPromptGeneration": "Use for Prompt Generation"
"useForPromptGeneration": "Use for Prompt Generation",
"videos": "Videos",
"videosTab": "Videos you've created and saved within Invoke."
},
"hotkeys": {
"hotkeys": "Hotkeys",

View File

@@ -18,6 +18,7 @@ import { GallerySettingsPopover } from './GallerySettingsPopover/GallerySettings
import { GalleryUploadButton } from './GalleryUploadButton';
import { GallerySearch } from './ImageGrid/GallerySearch';
import { NewGallery } from './NewGallery';
import { VideoGallery } from './VideoGallery';
const COLLAPSE_STYLES: CSSProperties = { flexShrink: 0, minHeight: 0, width: '100%' };
@@ -42,6 +43,10 @@ export const GalleryPanel = memo(() => {
dispatch(galleryViewChanged('assets'));
}, [dispatch]);
const handleClickVideos = useCallback(() => {
dispatch(galleryViewChanged('videos'));
}, [dispatch]);
const handleClickSearch = useCallback(() => {
onResetSearchTerm();
if (!searchDisclosure.isOpen && galleryPanel.$isCollapsed.get()) {
@@ -83,6 +88,14 @@ export const GalleryPanel = memo(() => {
>
{t('gallery.assets')}
</Button>
<Button
tooltip={t('gallery.videosTab')}
onClick={handleClickVideos}
data-testid="videos-tab"
colorScheme={galleryView === 'videos' ? 'invokeBlue' : undefined}
>
{t('gallery.videos')}
</Button>
</ButtonGroup>
<Flex flexGrow={1} flexBasis={0} justifyContent="flex-end">
<GalleryUploadButton />
@@ -109,7 +122,7 @@ export const GalleryPanel = memo(() => {
</Collapse>
<Divider pt={2} />
<Flex w="full" h="full" pt={2}>
<NewGallery />
{galleryView === 'images' ? <NewGallery /> : galleryView === 'videos' ? <VideoGallery /> : <NewGallery />}
</Flex>
</Flex>
);

View File

@@ -0,0 +1,272 @@
import { combine } from '@atlaskit/pragmatic-drag-and-drop/combine';
import { draggable, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop/element/adapter';
import type { FlexProps, SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import type { AppDispatch, AppGetState } from 'app/store/store';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { uniq } from 'es-toolkit';
import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd';
import type { DndDragPreviewMultipleImageState } from 'features/dnd/DndDragPreviewMultipleImage';
import { createMultipleImageDragPreview, setMultipleImageDragPreview } from 'features/dnd/DndDragPreviewMultipleImage';
import type { DndDragPreviewSingleImageState } from 'features/dnd/DndDragPreviewSingleImage';
import { createSingleImageDragPreview, setSingleImageDragPreview } from 'features/dnd/DndDragPreviewSingleImage';
import { firefoxDndFix } from 'features/dnd/util';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { GalleryImageHoverIcons } from 'features/gallery/components/ImageGrid/GalleryImageHoverIcons';
import {
selectGetImageNamesQueryArgs,
selectGetVideoIdsQueryArgs,
selectSelectedBoardId,
selectSelection,
} from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectGallerySlice, selectionChanged } from 'features/gallery/store/gallerySlice';
import { navigationApi } from 'features/ui/layouts/navigation-api';
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
import type { MouseEvent, MouseEventHandler } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { PiImageBold } from 'react-icons/pi';
import { imagesApi } from 'services/api/endpoints/images';
import { videosApi } from 'services/api/endpoints/videos';
import type { ImageDTO, VideoDTO } from 'services/api/types';
const galleryImageContainerSX = {
containerType: 'inline-size',
w: 'full',
h: 'full',
'.gallery-image-size-badge': {
'@container (max-width: 80px)': {
'&': { display: 'none' },
},
},
'&[data-is-dragging=true]': {
opacity: 0.3,
},
userSelect: 'none',
webkitUserSelect: 'none',
position: 'relative',
justifyContent: 'center',
alignItems: 'center',
aspectRatio: '1/1',
'::before': {
content: '""',
display: 'inline-block',
position: 'absolute',
top: 0,
left: 0,
right: 0,
bottom: 0,
pointerEvents: 'none',
borderRadius: 'base',
},
'&[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-500), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-300), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
'&:hover::before': {
boxShadow:
'inset 0px 0px 0px 1px var(--invoke-colors-invokeBlue-300), inset 0px 0px 0px 2px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeBlue-400), inset 0px 0px 0px 4px var(--invoke-colors-invokeBlue-800)',
},
'&:hover[data-selected-for-compare=true]::before': {
boxShadow:
'inset 0px 0px 0px 3px var(--invoke-colors-invokeGreen-200), inset 0px 0px 0px 4px var(--invoke-colors-invokeGreen-800)',
},
} satisfies SystemStyleObject;
interface Props {
videoDTO: VideoDTO;
}
const buildOnClick =
(videoId: string, dispatch: AppDispatch, getState: AppGetState) => (e: MouseEvent<HTMLDivElement>) => {
const { shiftKey, ctrlKey, metaKey, altKey } = e;
const state = getState();
const queryArgs = selectGetVideoIdsQueryArgs(state);
const videoIds = videosApi.endpoints.getVideoIds.select(queryArgs)(state).data?.video_ids ?? [];
// If we don't have the image names cached, we can't perform selection operations
// This can happen if the user clicks on an image before the names are loaded
if (videoIds.length === 0) {
// For basic click without modifiers, we can still set selection
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
dispatch(selectionChanged([videoId]));
}
return;
}
const selection = state.gallery.selection;
if (shiftKey) {
const rangeEndVideoId = videoId;
const lastSelectedVideoId = selection.at(-1);
const lastClickedIndex = videoIds.findIndex((id) => id === lastSelectedVideoId);
const currentClickedIndex = videoIds.findIndex((id) => id === rangeEndVideoId);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const videosToSelect = videoIds.slice(start, end + 1);
dispatch(selectionChanged(uniq(selection.concat(videosToSelect))));
}
} else if (ctrlKey || metaKey) {
if (selection.some((n) => n === videoId) && selection.length > 1) {
dispatch(selectionChanged(uniq(selection.filter((n) => n !== videoId))));
} else {
dispatch(selectionChanged(uniq(selection.concat(videoId))));
}
} else {
dispatch(selectionChanged([videoId]));
}
};
export const GalleryVideo = memo(({ videoDTO }: Props) => {
const store = useAppStore();
const [isDragging, setIsDragging] = useState(false);
const [dragPreviewState, setDragPreviewState] = useState<
DndDragPreviewSingleImageState | DndDragPreviewMultipleImageState | null
>(null);
const ref = useRef<HTMLDivElement>(null);
const selectIsSelected = useMemo(
() => createSelector(selectGallerySlice, (gallery) => gallery.selection.includes(videoDTO.video_id)),
[videoDTO.video_id]
);
const isSelected = useAppSelector(selectIsSelected);
useEffect(() => {
const element = ref.current;
if (!element) {
return;
}
return combine(
firefoxDndFix(element),
draggable({
element,
// getInitialData: () => {
// const selection = selectSelection(store.getState());
// const boardId = selectSelectedBoardId(store.getState());
// // When we have multiple images selected, and the dragged image is part of the selection, initiate a
// // multi-image drag.
// if (selection.length > 1 && selection.includes(videoDTO.video_id)) {
// return multipleImageDndSource.getData({
// image_names: selection,
// board_id: boardId,
// });
// }
// // Otherwise, initiate a single-image drag
// return singleImageDndSource.getData({ videoDTO }, videoDTO.video_id);
// },
// This is a "local" drag start event, meaning that it is only called when this specific image is dragged.
onDragStart: ({ source }) => {
// When we start dragging a single image, set the dragging state to true. This is only called when this
// specific image is dragged.
if (singleImageDndSource.typeGuard(source.data)) {
setIsDragging(true);
return;
}
},
onGenerateDragPreview: (args) => {
if (multipleImageDndSource.typeGuard(args.source.data)) {
setMultipleImageDragPreview({
multipleImageDndData: args.source.data,
onGenerateDragPreviewArgs: args,
setDragPreviewState,
});
} else if (singleImageDndSource.typeGuard(args.source.data)) {
setSingleImageDragPreview({
singleImageDndData: args.source.data,
onGenerateDragPreviewArgs: args,
setDragPreviewState,
});
}
},
}),
// monitorForElements({
// // This is a "global" drag start event, meaning that it is called for all drag events.
// onDragStart: ({ source }) => {
// // When we start dragging multiple images, set the dragging state to true if the dragged image is part of the
// // selection. This is called for all drag events.
// if (
// multipleImageDndSource.typeGuard(source.data) &&
// source.data.payload.video_ids.includes(videoDTO.video_id)
// ) {
// setIsDragging(true);
// }
// },
// onDrop: () => {
// // Always set the dragging state to false when a drop event occurs.
// setIsDragging(false);
// },
// })
);
}, [videoDTO, store]);
const [isHovered, setIsHovered] = useState(false);
const onMouseOver = useCallback(() => {
setIsHovered(true);
}, []);
const onMouseOut = useCallback(() => {
setIsHovered(false);
}, []);
const onClick = useMemo(() => buildOnClick(videoDTO.video_id, store.dispatch, store.getState), [videoDTO, store]);
const onDoubleClick = useCallback<MouseEventHandler<HTMLDivElement>>(() => {
store.dispatch(imageToCompareChanged(null));
navigationApi.focusPanelInActiveTab(VIEWER_PANEL_ID);
}, [store]);
// useImageContextMenu(videoDTO, ref);
return (
<>
<Flex
ref={ref}
sx={galleryImageContainerSX}
data-is-dragging={isDragging}
data-video-id={videoDTO.video_id}
role="button"
onMouseOver={onMouseOver}
onMouseOut={onMouseOut}
onClick={onClick}
onDoubleClick={onDoubleClick}
data-selected={isSelected}
>
<Image
pointerEvents="none"
src={videoDTO.thumbnail_url}
w={videoDTO.width}
fallback={<GalleryVideoPlaceholder />}
objectFit="contain"
maxW="full"
maxH="full"
borderRadius="base"
/>
{/* <GalleryImageHoverIcons videoDTO={videoDTO} isHovered={isHovered} /> */}
</Flex>
{dragPreviewState?.type === 'multiple-image' ? createMultipleImageDragPreview(dragPreviewState) : null}
{dragPreviewState?.type === 'single-image' ? createSingleImageDragPreview(dragPreviewState) : null}
</>
);
});
GalleryVideo.displayName = 'GalleryVideo';
export const GalleryVideoPlaceholder = memo((props: FlexProps) => (
<Flex w="full" h="full" bg="base.850" borderRadius="base" alignItems="center" justifyContent="center" {...props}>
<Icon as={PiImageBold} boxSize={16} color="base.800" />
</Flex>
));
GalleryVideoPlaceholder.displayName = 'GalleryVideoPlaceholder';

View File

@@ -7,6 +7,7 @@ import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBased
import type { selectGetImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
selectGalleryImageMinimumWidth,
selectGalleryView,
selectImageToCompare,
selectLastSelectedImage,
selectSelection,
@@ -32,6 +33,7 @@ import { useDebounce } from 'use-debounce';
import { GalleryImage, GalleryImagePlaceholder } from './ImageGrid/GalleryImage';
import { GallerySelectionCountTag } from './ImageGrid/GallerySelectionCountTag';
import { useGalleryImageNames } from './use-gallery-image-names';
import { useGalleryVideoIds } from './use-gallery-video-ids';
const log = logger('gallery');
@@ -526,9 +528,11 @@ export const NewGallery = memo(() => {
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const rootRef = useRef<HTMLDivElement>(null);
const galleryView = useAppSelector(selectGalleryView);
// Get the ordered list of image names - this is our primary data source for virtualization
const { queryArgs, imageNames, isLoading } = useGalleryImageNames();
const { queryArgs: videoQueryArgs, videoIds, isLoading: isLoadingVideos } = useGalleryVideoIds();
// Use range-based fetching for bulk loading image DTOs into cache based on the visible range
const { onRangeChanged } = useRangeBasedImageFetching({
@@ -553,7 +557,7 @@ export const NewGallery = memo(() => {
[onRangeChanged]
);
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]);
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs, videoIds, videoQueryArgs }), [imageNames, queryArgs, videoIds, videoQueryArgs]);
if (isLoading) {
return (
@@ -578,7 +582,7 @@ export const NewGallery = memo(() => {
<VirtuosoGrid<string, GridContext>
ref={virtuosoRef}
context={context}
data={imageNames}
data={galleryView === 'images' ? imageNames : videoIds}
increaseViewportBy={4096}
itemContent={itemContent}
computeItemKey={computeItemKey}

View File

@@ -0,0 +1,578 @@
import { Box, Flex, forwardRef, Grid, GridItem, Spinner, Text } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { getFocusedRegion, useIsRegionFocused } from 'common/hooks/focus';
import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBasedImageFetching';
import type { selectGetImageNamesQueryArgs, selectGetVideoIdsQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
selectGalleryImageMinimumWidth,
selectGalleryView,
selectImageToCompare,
selectLastSelectedImage,
selectSelectionCount,
} from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import type { MutableRefObject, RefObject } from 'react';
import React, { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import type {
GridComponents,
GridComputeItemKey,
GridItemContent,
ListRange,
ScrollSeekConfiguration,
VirtuosoGridHandle,
} from 'react-virtuoso';
import { VirtuosoGrid } from 'react-virtuoso';
import { useDebounce } from 'use-debounce';
import { GallerySelectionCountTag } from './ImageGrid/GallerySelectionCountTag';
import { useGalleryImageNames } from './use-gallery-image-names';
import { useGalleryVideoIds } from './use-gallery-video-ids';
import { videosApi } from 'services/api/endpoints/videos';
import { GalleryImagePlaceholder } from './ImageGrid/GalleryImage';
import { useRangeBasedVideoFetching } from '../hooks/useRangeBasedVideoFetching';
import { GalleryVideo } from './ImageGrid/GalleryVideo';
const log = logger('gallery');
type ListVideoIdsQueryArgs = ReturnType<typeof selectGetVideoIdsQueryArgs>;
type GridContext = {
queryArgs: ListVideoIdsQueryArgs;
videoIds: string[];
};
const VideoAtPosition = memo(({ videoId }: { index: number; videoId: string }) => {
/*
* We rely on the useRangeBasedImageFetching to fetch all image DTOs, caching them with RTK Query.
*
* In this component, we just want to consume that cache. Unforutnately, RTK Query does not provide a way to
* subscribe to a query without triggering a new fetch.
*
* There is a hack, though:
* - https://github.com/reduxjs/redux-toolkit/discussions/4213
*
* This essentially means "subscribe to the query once it has some data".
*/
// Use `currentData` instead of `data` to prevent a flash of previous image rendered at this index
const { currentData: videoDTO, isUninitialized } = videosApi.endpoints.getVideoDTO.useQueryState(videoId);
videosApi.endpoints.getVideoDTO.useQuerySubscription(videoId, { skip: isUninitialized });
if (!videoDTO) {
return <GalleryImagePlaceholder data-video-id={videoId} />;
}
return <GalleryVideo videoDTO={videoDTO} />;
});
VideoAtPosition.displayName = 'VideoAtPosition';
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
return `${JSON.stringify(queryArgs)}-${imageName ?? index}`;
};
/**
* Calculate how many images fit in a row based on the current grid layout.
*
* TODO(psyche): We only need to do this when the gallery width changes, or when the galleryImageMinimumWidth value
* changes. Cache this calculation.
*/
const getVideosPerRow = (rootEl: HTMLDivElement): number => {
// Start from root and find virtuoso grid elements
const gridElement = rootEl.querySelector('.virtuoso-grid-list');
if (!gridElement) {
return 0;
}
const firstGridItem = gridElement.querySelector('.virtuoso-grid-item');
if (!firstGridItem) {
return 0;
}
const itemRect = firstGridItem.getBoundingClientRect();
const containerRect = gridElement.getBoundingClientRect();
// Get the computed gap from CSS
const gridStyle = window.getComputedStyle(gridElement);
const gapValue = gridStyle.gap;
const gap = parseFloat(gapValue);
if (isNaN(gap) || !itemRect.width || !itemRect.height || !containerRect.width || !containerRect.height) {
return 0;
}
/**
* You might be tempted to just do some simple math like:
* const imagesPerRow = Math.floor(containerRect.width / itemRect.width);
*
* But floating point precision can cause issues with this approach, causing it to be off by 1 in some cases.
*
* Instead, we use a more robust approach that iteratively calculates how many images fit in the row.
*/
let videosPerRow = 0;
let spaceUsed = 0;
// Floating point precision can cause imagesPerRow to be 1 too small. Adding 1px to the container size fixes
// this, without the possibility of accidentally adding an extra column.
while (spaceUsed + itemRect.width <= containerRect.width + 1) {
videosPerRow++; // Increment the number of images
spaceUsed += itemRect.width; // Add image size to the used space
if (spaceUsed + gap <= containerRect.width) {
spaceUsed += gap; // Add gap size to the used space after each image except after the last image
}
}
return Math.max(1, videosPerRow);
};
/**
* Scroll the item at the given index into view if it is not currently visible.
*/
const scrollIntoView = (
targetVideoId: string,
videoIds: string[],
rootEl: HTMLDivElement,
virtuosoGridHandle: VirtuosoGridHandle,
range: ListRange
) => {
if (range.endIndex === 0) {
// No range is rendered; no need to scroll to anything.
return;
}
const targetIndex = videoIds.findIndex((id) => id === targetVideoId);
if (targetIndex === -1) {
// The image isn't in the currently rendered list.
return;
}
const targetItem = rootEl.querySelector(
`.virtuoso-grid-item:has([data-video-id="${targetVideoId}"])`
) as HTMLElement;
if (!targetItem) {
if (targetIndex > range.endIndex) {
virtuosoGridHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'start',
});
} else if (targetIndex < range.startIndex) {
virtuosoGridHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'end',
});
} else {
log.debug(
`Unable to find video ${targetVideoId} at index ${targetIndex} but it is in the rendered range ${range.startIndex}-${range.endIndex}`
);
}
return;
}
// We found the image in the DOM, but it might be in the overscan range - rendered but not in the visible viewport.
// Check if it is in the viewport and scroll if necessary.
const itemRect = targetItem.getBoundingClientRect();
const rootRect = rootEl.getBoundingClientRect();
if (itemRect.top < rootRect.top) {
virtuosoGridHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'start',
});
} else if (itemRect.bottom > rootRect.bottom) {
virtuosoGridHandle.scrollToIndex({
index: targetIndex,
behavior: 'auto',
align: 'end',
});
} else {
// Image is already in view
}
return;
};
/**
* Get the index of the image in the list of image names.
* If the image name is not found, return 0.
* If no image name is provided, return 0.
*/
const getVideoIndex = (videoId: string | undefined | null, videoIds: string[]) => {
if (!videoId || videoIds.length === 0) {
return 0;
}
const index = videoIds.findIndex((n) => n === videoId);
return index >= 0 ? index : 0;
};
/**
* Handles keyboard navigation for the gallery.
*/
const useKeyboardNavigation = (
videoIds: string[],
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
rootRef: React.RefObject<HTMLDivElement>
) => {
const { dispatch, getState } = useAppStore();
const handleKeyDown = useCallback(
(event: KeyboardEvent) => {
if (getFocusedRegion() !== 'gallery') {
// Only handle keyboard navigation when the gallery is focused
return;
}
// Only handle arrow keys
if (!['ArrowUp', 'ArrowDown', 'ArrowLeft', 'ArrowRight'].includes(event.key)) {
return;
}
// Don't interfere if user is typing in an input
if (event.target instanceof HTMLInputElement || event.target instanceof HTMLTextAreaElement) {
return;
}
const rootEl = rootRef.current;
const virtuosoGridHandle = virtuosoRef.current;
if (!rootEl || !virtuosoGridHandle) {
return;
}
if (videoIds.length === 0) {
return;
}
const videosPerRow = getVideosPerRow(rootEl);
if (videosPerRow === 0) {
// This can happen if the grid is not yet rendered or has no items
return;
}
event.preventDefault();
const state = getState();
const videoId = event.altKey
? // When the user holds alt, we are changing the image to compare - if no image to compare is currently selected,
// we start from the last selected image
(selectImageToCompare(state) ?? selectLastSelectedImage(state))
: selectLastSelectedImage(state);
const currentIndex = getVideoIndex(videoId, videoIds);
let newIndex = currentIndex;
switch (event.key) {
case 'ArrowLeft':
if (currentIndex > 0) {
newIndex = currentIndex - 1;
// } else {
// // Wrap to last image
// newIndex = imageNames.length - 1;
}
break;
case 'ArrowRight':
if (currentIndex < videoIds.length - 1) {
newIndex = currentIndex + 1;
// } else {
// // Wrap to first image
// newIndex = 0;
}
break;
case 'ArrowUp':
// If on first row, stay on current image
if (currentIndex < videosPerRow) {
newIndex = currentIndex;
} else {
newIndex = Math.max(0, currentIndex - videosPerRow);
}
break;
case 'ArrowDown':
// If no images below, stay on current image
if (currentIndex >= videoIds.length - videosPerRow) {
newIndex = currentIndex;
} else {
newIndex = Math.min(videoIds.length - 1, currentIndex + videosPerRow);
}
break;
}
if (newIndex !== currentIndex && newIndex >= 0 && newIndex < videoIds.length) {
const newVideoId = videoIds[newIndex];
if (newVideoId) {
dispatch(selectionChanged([newVideoId]));
}
}
},
[rootRef, virtuosoRef, videoIds, getState, dispatch]
);
useRegisteredHotkeys({
id: 'galleryNavLeft',
category: 'gallery',
callback: handleKeyDown,
options: { preventDefault: true },
dependencies: [handleKeyDown],
});
useRegisteredHotkeys({
id: 'galleryNavRight',
category: 'gallery',
callback: handleKeyDown,
options: { preventDefault: true },
dependencies: [handleKeyDown],
});
useRegisteredHotkeys({
id: 'galleryNavUp',
category: 'gallery',
callback: handleKeyDown,
options: { preventDefault: true },
dependencies: [handleKeyDown],
});
useRegisteredHotkeys({
id: 'galleryNavDown',
category: 'gallery',
callback: handleKeyDown,
options: { preventDefault: true },
dependencies: [handleKeyDown],
});
useRegisteredHotkeys({
id: 'galleryNavLeftAlt',
category: 'gallery',
callback: handleKeyDown,
options: { preventDefault: true },
dependencies: [handleKeyDown],
});
useRegisteredHotkeys({
id: 'galleryNavRightAlt',
category: 'gallery',
callback: handleKeyDown,
options: { preventDefault: true },
dependencies: [handleKeyDown],
});
useRegisteredHotkeys({
id: 'galleryNavUpAlt',
category: 'gallery',
callback: handleKeyDown,
options: { preventDefault: true },
dependencies: [handleKeyDown],
});
useRegisteredHotkeys({
id: 'galleryNavDownAlt',
category: 'gallery',
callback: handleKeyDown,
options: { preventDefault: true },
dependencies: [handleKeyDown],
});
};
/**
* Keeps the last selected image in view when the gallery is scrolled.
* This is useful for keyboard navigation and ensuring the user can see their selection.
* It only tracks the last selected image, not the image to compare.
*/
const useKeepSelectedVideoInView = (
videoIds: string[],
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
rootRef: React.RefObject<HTMLDivElement>,
rangeRef: MutableRefObject<ListRange>
) => {
const targetVideoId = useAppSelector(selectLastSelectedImage);
useEffect(() => {
const virtuosoGridHandle = virtuosoRef.current;
const rootEl = rootRef.current;
const range = rangeRef.current;
if (!virtuosoGridHandle || !rootEl || !targetVideoId || !videoIds || videoIds.length === 0) {
return;
}
scrollIntoView(targetVideoId, videoIds, rootEl, virtuosoGridHandle, range);
}, [targetVideoId, videoIds, rangeRef, rootRef, virtuosoRef]);
};
/**
* Handles the initialization of the overlay scrollbars for the gallery, returning the ref to the scroller element.
*/
const useScrollableGallery = (rootRef: RefObject<HTMLDivElement>) => {
const [scroller, scrollerRef] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
defer: true,
events: {
initialized(osInstance) {
// force overflow styles
const { viewport } = osInstance.elements();
viewport.style.overflowX = `var(--os-viewport-overflow-x)`;
viewport.style.overflowY = `var(--os-viewport-overflow-y)`;
},
},
options: {
scrollbars: {
visibility: 'auto',
autoHide: 'scroll',
autoHideDelay: 1300,
theme: 'os-theme-dark',
},
},
});
useEffect(() => {
const { current: root } = rootRef;
if (scroller && root) {
initialize({
target: root,
elements: {
viewport: scroller,
},
});
}
return () => {
osInstance()?.destroy();
};
}, [scroller, initialize, osInstance, rootRef]);
return scrollerRef;
};
export const VideoGallery = memo(() => {
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const rootRef = useRef<HTMLDivElement>(null);
const galleryView = useAppSelector(selectGalleryView);
// Get the ordered list of image names - this is our primary data source for virtualization
const { queryArgs, videoIds, isLoading } = useGalleryVideoIds();
// Use range-based fetching for bulk loading image DTOs into cache based on the visible range
const { onRangeChanged } = useRangeBasedVideoFetching({
videoIds,
enabled: !isLoading,
});
useKeepSelectedVideoInView(videoIds, virtuosoRef, rootRef, rangeRef);
useKeyboardNavigation(videoIds, virtuosoRef, rootRef);
const scrollerRef = useScrollableGallery(rootRef);
/*
* We have to keep track of the visible range for keep-selected-image-in-view functionality and push the range to
* the range-based image fetching hook.
*/
const handleRangeChanged = useCallback(
(range: ListRange) => {
rangeRef.current = range;
onRangeChanged(range);
},
[onRangeChanged]
);
const context = useMemo<GridContext>(() => ({ videoIds, queryArgs }), [videoIds, queryArgs]);
if (isLoading) {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center" gap={4}>
<Spinner size="lg" opacity={0.3} />
<Text color="base.300">Loading gallery...</Text>
</Flex>
);
}
if (videoIds.length === 0) {
return (
<Flex w="full" h="full" alignItems="center" justifyContent="center">
<Text color="base.300">No videos found</Text>
</Flex>
);
}
return (
// This wrapper component is necessary to initialize the overlay scrollbars!
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
<VirtuosoGrid<string, GridContext>
ref={virtuosoRef}
context={context}
data={videoIds}
increaseViewportBy={4096}
itemContent={itemContent}
computeItemKey={computeItemKey}
components={components}
style={style}
scrollerRef={scrollerRef}
scrollSeekConfiguration={scrollSeekConfiguration}
rangeChanged={handleRangeChanged}
/>
<GallerySelectionCountTag />
</Box>
);
});
VideoGallery.displayName = 'VideoGallery';
const scrollSeekConfiguration: ScrollSeekConfiguration = {
enter: (velocity) => {
return Math.abs(velocity) > 2048;
},
exit: (velocity) => {
return velocity === 0;
},
};
// Styles
const style = { height: '100%', width: '100%' };
const selectGridTemplateColumns = createSelector(
selectGalleryImageMinimumWidth,
(galleryImageMinimumWidth) => `repeat(auto-fill, minmax(${galleryImageMinimumWidth}px, 1fr))`
);
// Grid components
const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context: _, ...rest }, ref) => {
const _gridTemplateColumns = useAppSelector(selectGridTemplateColumns);
const [gridTemplateColumns] = useDebounce(_gridTemplateColumns, 300);
return <Grid ref={ref} gridTemplateColumns={gridTemplateColumns} gap={1} {...rest} />;
});
ListComponent.displayName = 'ListComponent';
const itemContent: GridItemContent<string, GridContext> = (index, videoId) => {
return <VideoAtPosition index={index} videoId={videoId} />;
};
const ItemComponent: GridComponents<GridContext>['Item'] = forwardRef(({ context: _, ...rest }, ref) => (
<GridItem ref={ref} aspectRatio="1/1" {...rest} />
));
ItemComponent.displayName = 'ItemComponent';
const ScrollSeekPlaceholderComponent: GridComponents<GridContext>['ScrollSeekPlaceholder'] = (props) => (
<GridItem aspectRatio="1/1" {...props}>
<GalleryImagePlaceholder />
</GridItem>
);
ScrollSeekPlaceholderComponent.displayName = 'ScrollSeekPlaceholderComponent';
const components: GridComponents<GridContext> = {
Item: ItemComponent,
List: ListComponent,
ScrollSeekPlaceholder: ScrollSeekPlaceholderComponent,
};

View File

@@ -0,0 +1,21 @@
import { EMPTY_ARRAY } from 'app/store/constants';
import { useAppSelector } from 'app/store/storeHooks';
import { selectGetImageNamesQueryArgs, selectGetVideoIdsQueryArgs } from 'features/gallery/store/gallerySelectors';
import { useGetVideoIdsQuery } from 'services/api/endpoints/videos';
import { useDebounce } from 'use-debounce';
const getVideoIdsQueryOptions = {
refetchOnReconnect: true,
selectFromResult: ({ currentData, isLoading, isFetching }) => ({
videoIds: currentData?.video_ids ?? EMPTY_ARRAY,
isLoading,
isFetching,
}),
} satisfies Parameters<typeof useGetVideoIdsQuery>[1];
export const useGalleryVideoIds = () => {
const _queryArgs = useAppSelector(selectGetVideoIdsQueryArgs);
const [queryArgs] = useDebounce(_queryArgs, 300);
const { videoIds, isLoading, isFetching } = useGetVideoIdsQuery(queryArgs, getVideoIdsQueryOptions);
return { videoIds, isLoading, isFetching, queryArgs };
};

View File

@@ -0,0 +1,78 @@
import { useAppStore } from 'app/store/storeHooks';
import { useCallback, useEffect, useState } from 'react';
import type { ListRange } from 'react-virtuoso';
import { videosApi, useGetVideoDTOsByNamesMutation } from 'services/api/endpoints/videos';
import { useThrottledCallback } from 'use-debounce';
interface UseRangeBasedVideoFetchingArgs {
videoIds: string[];
enabled: boolean;
}
interface UseRangeBasedVideoFetchingReturn {
onRangeChanged: (range: ListRange) => void;
}
const getUncachedIds = (videoIds: string[], cachedVideoIds: string[], ranges: ListRange[]): string[] => {
const uncachedIdsSet = new Set<string>();
const cachedVideoIdsSet = new Set(cachedVideoIds);
for (const range of ranges) {
for (let i = range.startIndex; i <= range.endIndex; i++) {
const id = videoIds[i]!;
if (id && !cachedVideoIdsSet.has(id)) {
uncachedIdsSet.add(id);
}
}
}
return Array.from(uncachedIdsSet);
};
/**
* Hook for bulk fetching image DTOs based on the visible range from virtuoso.
* Individual image components should use `useGetImageDTOQuery(imageName)` to get their specific DTO.
* This hook ensures DTOs are bulk fetched and cached efficiently.
*/
export const useRangeBasedVideoFetching = ({
videoIds,
enabled,
}: UseRangeBasedVideoFetchingArgs): UseRangeBasedVideoFetchingReturn => {
const store = useAppStore();
const [getVideoDTOsByNames] = useGetVideoDTOsByNamesMutation();
const [lastRange, setLastRange] = useState<ListRange | null>(null);
const [pendingRanges, setPendingRanges] = useState<ListRange[]>([]);
const fetchVideos = useCallback(
(ranges: ListRange[], videoIds: string[]) => {
if (!enabled) {
return;
}
const cachedVideoIds = videosApi.util.selectCachedArgsForQuery(store.getState(), 'getVideoDTO');
const uncachedIds = getUncachedIds(videoIds, cachedVideoIds, ranges);
console.log('uncachedIds', uncachedIds);
if (uncachedIds.length === 0) {
return;
}
getVideoDTOsByNames({ video_ids: uncachedIds });
setPendingRanges([]);
},
[enabled, getVideoDTOsByNames, store]
);
const throttledFetchVideos = useThrottledCallback(fetchVideos, 500);
const onRangeChanged = useCallback((range: ListRange) => {
setLastRange(range);
setPendingRanges((prev) => [...prev, range]);
}, []);
useEffect(() => {
const combinedRanges = lastRange ? [...pendingRanges, lastRange] : pendingRanges;
throttledFetchVideos(combinedRanges, videoIds);
}, [videoIds, lastRange, pendingRanges, throttledFetchVideos]);
return {
onRangeChanged,
};
};

View File

@@ -2,7 +2,7 @@ import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { ASSETS_CATEGORIES, IMAGE_CATEGORIES } from 'features/gallery/store/types';
import type { GetImageNamesArgs, ListBoardsArgs } from 'services/api/types';
import type { GetImageNamesArgs, GetVideoIdsArgs, ListBoardsArgs } from 'services/api/types';
export const selectFirstSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(0));
export const selectLastSelectedImage = createSelector(selectGallerySlice, (gallery) => gallery.selection.at(-1));
@@ -20,9 +20,15 @@ export const selectAutoAddBoardId = createSelector(selectGallerySlice, (gallery)
export const selectAutoSwitch = createSelector(selectGallerySlice, (gallery) => gallery.shouldAutoSwitch);
export const selectSelectedBoardId = createSelector(selectGallerySlice, (gallery) => gallery.selectedBoardId);
export const selectGalleryView = createSelector(selectGallerySlice, (gallery) => gallery.galleryView);
const selectGalleryQueryCategories = createSelector(selectGalleryView, (galleryView) =>
galleryView === 'images' ? IMAGE_CATEGORIES : ASSETS_CATEGORIES
);
const selectGalleryQueryCategories = createSelector(selectGalleryView, (galleryView) => {
if (galleryView === 'images') {
return IMAGE_CATEGORIES;
}
if (galleryView === 'videos') {
return [];
}
return ASSETS_CATEGORIES;
});
const selectGallerySearchTerm = createSelector(selectGallerySlice, (gallery) => gallery.searchTerm);
const selectGalleryOrderDir = createSelector(selectGallerySlice, (gallery) => gallery.orderDir);
const selectGalleryStarredFirst = createSelector(selectGallerySlice, (gallery) => gallery.starredFirst);
@@ -44,6 +50,23 @@ export const selectGetImageNamesQueryArgs = createMemoizedSelector(
is_intermediate: false,
})
);
export const selectGetVideoIdsQueryArgs = createMemoizedSelector(
[
selectSelectedBoardId,
selectGallerySearchTerm,
selectGalleryOrderDir,
selectGalleryStarredFirst,
],
(board_id, search_term, order_dir, starred_first): GetVideoIdsArgs => ({
board_id,
search_term,
order_dir,
starred_first,
is_intermediate: false,
})
);
export const selectAutoAssignBoardOnClick = createSelector(
selectGallerySlice,
(gallery) => gallery.autoAssignBoardOnClick

View File

@@ -1,7 +1,7 @@
import type { ImageCategory } from 'services/api/types';
import z from 'zod';
const zGalleryView = z.enum(['images', 'assets']);
const zGalleryView = z.enum(['images', 'assets', 'videos']);
export type GalleryView = z.infer<typeof zGalleryView>;
const zBoardId = z.string();
// TS hack to get autocomplete for "none" but accept any string

View File

@@ -1,38 +1,54 @@
import { Box, Flex, Text } from '@invoke-ai/ui-library';
import { useFocusRegion } from 'common/hooks/focus';
import { memo, useRef } from 'react';
import { memo, useMemo, useRef } from 'react';
import { useTranslation } from 'react-i18next';
import ReactPlayer from 'react-player';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { selectGeneratedVideo } from 'features/parameters/store/videoSlice';
import { useGetVideoDTOQuery } from 'services/api/endpoints/videos';
import { skipToken } from '@reduxjs/toolkit/query';
import { useImageDTO } from 'services/api/endpoints/images';
import { selectLastSelectedImage } from 'features/gallery/store/gallerySelectors';
export const VideoPlayerPanel = memo(() => {
const { t } = useTranslation();
const ref = useRef<HTMLDivElement>(null);
const generatedVideo = useAppSelector(selectGeneratedVideo);
const lastSelectedVideoId = useAppSelector(selectLastSelectedImage);
const {data: videoDTO} = useGetVideoDTOQuery(lastSelectedVideoId ?? skipToken);
useFocusRegion('video', ref);
const videoUrl = useMemo(() => {
// if (generatedVideo) {
// return generatedVideo.video_url;
// }
if (!videoDTO) {
return null;
}
return videoDTO.video_url;
}, [videoDTO]);
return (
<Flex ref={ref} w="full" h="full" flexDirection="column" gap={4}>
{generatedVideo &&
{videoUrl &&
<>
<Box flex={0.75} position="relative" >
{/* <ReactPlayer
src={generatedVideo.url}
<ReactPlayer
src={videoUrl}
width="75%"
height="75%"
controls={true}
style={{ position: 'absolute', top: '50%', left: '50%', transform: 'translate(-50%, -50%)', maxWidth: '900px' }}
/> */}
/>
</Box>
</>}
{!generatedVideo && <Text>No video generated</Text>}
{!videoUrl && <Text>No video generated</Text>}
</Flex>
);

View File

@@ -0,0 +1,88 @@
import type { paths } from 'services/api/schema';
import type {
GetVideoIdsArgs,
GetVideoIdsResult,
VideoDTO,
} from 'services/api/types';
import stableHash from 'stable-hash';
import type { Param0 } from 'tsafe';
import { api, buildV1Url, LIST_TAG } from '..';
/**
* Builds an endpoint URL for the videos router
* @example
* buildVideosUrl('some-path')
* // '/api/v1/videos/some-path'
*/
const buildVideosUrl = (path: string = '', query?: Parameters<typeof buildV1Url>[1]) =>
buildV1Url(`videos/${path}`, query);
export const videosApi = api.injectEndpoints({
endpoints: (build) => ({
/**
* Video Queries
*/
getVideoDTO: build.query<VideoDTO, string>({
query: (video_id) => ({ url: buildVideosUrl(`i/${video_id}`) }),
providesTags: (result, error, video_id) => [{ type: 'Video', id: video_id }],
}),
/**
* Get ordered list of image names for selection operations
*/
getVideoIds: build.query<GetVideoIdsResult, GetVideoIdsArgs>({
query: (queryArgs) => ({
url: buildVideosUrl('ids', queryArgs),
method: 'GET',
}),
providesTags: (result, error, queryArgs) => [
'VideoIdList',
'FetchOnReconnect',
{ type: 'VideoIdList', id: stableHash(queryArgs) },
],
}),
/**
* Get image DTOs for the specified image names. Maintains order of input names.
*/
getVideoDTOsByNames: build.mutation<
paths['/api/v1/videos/videos_by_ids']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/videos/videos_by_ids']['post']['requestBody']['content']['application/json']
>({
query: (body) => ({
url: buildVideosUrl('videos_by_ids'),
method: 'POST',
body,
}),
// Don't provide cache tags - we'll manually upsert into individual getImageDTO caches
async onQueryStarted(_, { dispatch, queryFulfilled }) {
try {
const { data: videoDTOs } = await queryFulfilled;
// Upsert each DTO into the individual image cache
const updates: Param0<typeof videosApi.util.upsertQueryEntries> = [];
for (const videoDTO of videoDTOs) {
updates.push({
endpointName: 'getVideoDTO',
arg: videoDTO.video_id,
value: videoDTO,
});
}
dispatch(videosApi.util.upsertQueryEntries(updates));
} catch {
// Handle error if needed
}
},
}),
}),
});
export const {
useGetVideoDTOQuery,
useGetVideoIdsQuery,
useGetVideoDTOsByNamesMutation,
} = videosApi;

View File

@@ -54,6 +54,8 @@ const tagTypes = [
'StylePreset',
'Schema',
'QueueCountsByDestination',
'Video',
'VideoIdList',
// This is invalidated on reconnect. It should be used for queries that have changing data,
// especially related to the queue and generation.
'FetchOnReconnect',

View File

@@ -14,6 +14,11 @@ export type GetImageNamesResult =
paths['/api/v1/images/names']['get']['responses']['200']['content']['application/json'];
export type GetImageNamesArgs = NonNullable<paths['/api/v1/images/names']['get']['parameters']['query']>;
export type GetVideoIdsResult =
paths['/api/v1/videos/ids']['get']['responses']['200']['content']['application/json'];
export type GetVideoIdsArgs = NonNullable<paths['/api/v1/videos/ids']['get']['parameters']['query']>;
export type ListBoardsArgs = NonNullable<paths['/api/v1/boards/']['get']['parameters']['query']>;
export type CreateBoardArg = paths['/api/v1/boards/']['post']['parameters']['query'];
@@ -68,6 +73,26 @@ assert<Equals<ImageDTO, S['ImageDTO']>>();
export type BoardDTO = S['BoardDTO'];
export type OffsetPaginatedResults_ImageDTO_ = S['OffsetPaginatedResults_ImageDTO_'];
// Videos
const _zVideoDTO = z.object({
video_id: z.string(),
video_url: z.string(),
thumbnail_url: z.string(),
width: z.number().int().gt(0),
height: z.number().int().gt(0),
created_at: z.string(),
updated_at: z.string(),
deleted_at: z.string().nullish(),
starred: z.boolean(),
board_id: z.string().nullish(),
is_intermediate: z.boolean(),
session_id: z.string().nullish(),
node_id: z.string().nullish(),
});
export type VideoDTO = z.infer<typeof _zVideoDTO>;
assert<Equals<VideoDTO, S['VideoDTO']>>();
export type OffsetPaginatedResults_VideoDTO_ = S['OffsetPaginatedResults_VideoDTO_'];
// Models
export type ModelType = S['ModelType'];
export type BaseModelType = S['BaseModelType'];

View File

@@ -237,7 +237,7 @@ export const buildOnInvocationComplete = (
const videoResult = await getResultVideoDTOs(data);
if (videoResult) {
dispatch(generatedVideoChanged(videoResult));
dispatch(generatedVideoChanged({ video_id: videoResult.video.video_id, type: 'video_output' }));
}
$lastProgressEvent.set(null);