mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-12 15:24:55 -05:00
refactor: optimistic gallery updates
This commit is contained in:
@@ -14,6 +14,7 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
)
|
||||
@@ -576,11 +577,11 @@ async def get_image_names(
|
||||
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
|
||||
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
|
||||
search_term: Optional[str] = Query(default=None, description="The term to search for"),
|
||||
) -> list[str]:
|
||||
"""Gets ordered list of all image names (starred first, then unstarred)"""
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates"""
|
||||
|
||||
try:
|
||||
image_names = ApiDependencies.invoker.services.images.get_image_names(
|
||||
result = ApiDependencies.invoker.services.images.get_image_names(
|
||||
starred_first=starred_first,
|
||||
order_dir=order_dir,
|
||||
image_origin=image_origin,
|
||||
@@ -589,6 +590,34 @@ async def get_image_names(
|
||||
board_id=board_id,
|
||||
search_term=search_term,
|
||||
)
|
||||
return image_names
|
||||
return result
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image names")
|
||||
|
||||
|
||||
@images_router.post(
|
||||
"/images_by_names",
|
||||
operation_id="get_images_by_names",
|
||||
responses={200: {"model": list[ImageDTO]}},
|
||||
)
|
||||
async def get_images_by_names(
|
||||
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
|
||||
) -> list[ImageDTO]:
|
||||
"""Gets image DTOs for the specified image names. Maintains order of input names."""
|
||||
|
||||
try:
|
||||
image_service = ApiDependencies.invoker.services.images
|
||||
|
||||
# Fetch DTOs preserving the order of requested names
|
||||
image_dtos: list[ImageDTO] = []
|
||||
for name in image_names:
|
||||
try:
|
||||
dto = image_service.get_dto(name)
|
||||
image_dtos.append(dto)
|
||||
except Exception:
|
||||
# Skip missing images - they may have been deleted between name fetch and DTO fetch
|
||||
continue
|
||||
|
||||
return image_dtos
|
||||
except Exception:
|
||||
raise HTTPException(status_code=500, detail="Failed to get image DTOs")
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
@@ -108,6 +109,6 @@ class ImageRecordStorageBase(ABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""Gets ordered list of all image names (starred first, then unstarred)."""
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates."""
|
||||
pass
|
||||
|
||||
@@ -212,3 +212,10 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
|
||||
class ImageCollectionCounts(BaseModel):
|
||||
starred_count: int = Field(description="The number of starred images in the collection.")
|
||||
unstarred_count: int = Field(description="The number of unstarred images in the collection.")
|
||||
|
||||
|
||||
class ImageNamesResult(BaseModel):
|
||||
"""Response containing ordered image names with metadata for optimistic updates."""
|
||||
image_names: list[str] = Field(description="Ordered list of image names")
|
||||
starred_count: int = Field(description="Number of starred images (when starred_first=True)")
|
||||
total_count: int = Field(description="Total number of images matching the query")
|
||||
|
||||
@@ -7,6 +7,7 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
IMAGE_DTO_COLS,
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
@@ -396,17 +397,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
) -> ImageNamesResult:
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
# Base query to get image names in order (starred first, then unstarred)
|
||||
query = """--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1
|
||||
"""
|
||||
|
||||
# Build query conditions (reused for both starred count and image names queries)
|
||||
query_conditions = ""
|
||||
query_params: list[Union[int, str, bool]] = []
|
||||
|
||||
@@ -451,22 +445,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
query_params.append(f"%{search_term.lower()}%")
|
||||
|
||||
# Get starred count if starred_first is enabled
|
||||
starred_count = 0
|
||||
if starred_first:
|
||||
query += (
|
||||
query_conditions
|
||||
+ f"""--sql
|
||||
starred_count_query = f"""--sql
|
||||
SELECT COUNT(*)
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE images.starred = TRUE AND (1=1{query_conditions})
|
||||
"""
|
||||
cursor.execute(starred_count_query, query_params)
|
||||
starred_count = cast(int, cursor.fetchone()[0])
|
||||
|
||||
# Get all image names with proper ordering
|
||||
if starred_first:
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.starred DESC, images.created_at {order_dir.value}
|
||||
"""
|
||||
)
|
||||
else:
|
||||
query += (
|
||||
query_conditions
|
||||
+ f"""--sql
|
||||
names_query = f"""--sql
|
||||
SELECT images.image_name
|
||||
FROM images
|
||||
LEFT JOIN board_images ON board_images.image_name = images.image_name
|
||||
WHERE 1=1{query_conditions}
|
||||
ORDER BY images.created_at {order_dir.value}
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(query, query_params)
|
||||
cursor.execute(names_query, query_params)
|
||||
result = cast(list[sqlite3.Row], cursor.fetchall())
|
||||
image_names = [row[0] for row in result]
|
||||
|
||||
return [row[0] for row in result]
|
||||
return ImageNamesResult(
|
||||
image_names=image_names,
|
||||
starred_count=starred_count,
|
||||
total_count=len(image_names)
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType
|
||||
from invokeai.app.invocations.fields import MetadataField
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ResourceOrigin,
|
||||
@@ -158,6 +159,6 @@ class ImageServiceABC(ABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
"""Gets ordered list of all image names."""
|
||||
) -> ImageNamesResult:
|
||||
"""Gets ordered list of image names with metadata for optimistic updates."""
|
||||
pass
|
||||
|
||||
@@ -10,6 +10,7 @@ from invokeai.app.services.image_files.image_files_common import (
|
||||
)
|
||||
from invokeai.app.services.image_records.image_records_common import (
|
||||
ImageCategory,
|
||||
ImageNamesResult,
|
||||
ImageRecord,
|
||||
ImageRecordChanges,
|
||||
ImageRecordDeleteException,
|
||||
@@ -319,7 +320,7 @@ class ImageService(ImageServiceABC):
|
||||
is_intermediate: Optional[bool] = None,
|
||||
board_id: Optional[str] = None,
|
||||
search_term: Optional[str] = None,
|
||||
) -> list[str]:
|
||||
) -> ImageNamesResult:
|
||||
try:
|
||||
return self.__invoker.services.image_records.get_image_names(
|
||||
starred_first=starred_first,
|
||||
|
||||
@@ -10,7 +10,6 @@ import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddlew
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addEnsureImageIsSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/ensureImageIsSelectedListener';
|
||||
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
|
||||
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
|
||||
@@ -44,9 +43,6 @@ addImageUploadedFulfilledListener(startAppListening);
|
||||
// Image deleted
|
||||
addDeleteBoardAndImagesFulfilledListener(startAppListening);
|
||||
|
||||
// Gallery
|
||||
addGalleryImageClickedListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addEnqueueRequestedLinear(startAppListening);
|
||||
addEnqueueRequestedUpscale(startAppListening);
|
||||
|
||||
@@ -1,77 +0,0 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { uniq } from 'es-toolkit/compat';
|
||||
import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
|
||||
export const galleryImageClicked = createAction<{
|
||||
imageName: string;
|
||||
shiftKey: boolean;
|
||||
ctrlKey: boolean;
|
||||
metaKey: boolean;
|
||||
altKey: boolean;
|
||||
}>('gallery/imageClicked');
|
||||
|
||||
/**
|
||||
* This listener handles the logic for selecting images in the gallery.
|
||||
*
|
||||
* Previously, this logic was in a `useCallback` with the whole gallery selection as a dependency. Every time
|
||||
* the selection changed, the callback got recreated and all images rerendered. This could easily block for
|
||||
* hundreds of ms, more for lower end devices.
|
||||
*
|
||||
* Moving this logic into a listener means we don't need to recalculate anything dynamically and the gallery
|
||||
* is much more responsive.
|
||||
*/
|
||||
|
||||
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: galleryImageClicked,
|
||||
effect: (action, { dispatch, getState }) => {
|
||||
const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImageNamesQueryArgs(state);
|
||||
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data ?? [];
|
||||
|
||||
// If we don't have the image names cached, we can't perform selection operations
|
||||
// This can happen if the user clicks on an image before the names are loaded
|
||||
if (imageNames.length === 0) {
|
||||
// For basic click without modifiers, we can still set selection
|
||||
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
|
||||
dispatch(selectionChanged([imageName]));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const selection = state.gallery.selection;
|
||||
|
||||
if (altKey) {
|
||||
if (state.gallery.imageToCompare === imageName) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(imageToCompareChanged(imageName));
|
||||
}
|
||||
} else if (shiftKey) {
|
||||
const rangeEndImageName = imageName;
|
||||
const lastSelectedImage = selection.at(-1);
|
||||
const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage);
|
||||
const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName);
|
||||
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
|
||||
// We have a valid range!
|
||||
const start = Math.min(lastClickedIndex, currentClickedIndex);
|
||||
const end = Math.max(lastClickedIndex, currentClickedIndex);
|
||||
const imagesToSelect = imageNames.slice(start, end + 1);
|
||||
dispatch(selectionChanged(uniq(selection.concat(imagesToSelect))));
|
||||
}
|
||||
} else if (ctrlKey || metaKey) {
|
||||
if (selection.some((n) => n === imageName) && selection.length > 1) {
|
||||
dispatch(selectionChanged(uniq(selection.filter((n) => n !== imageName))));
|
||||
} else {
|
||||
dispatch(selectionChanged(uniq(selection.concat(imageName))));
|
||||
}
|
||||
} else {
|
||||
dispatch(selectionChanged([imageName]));
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -3,9 +3,10 @@ import { draggable, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Image } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { galleryImageClicked } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import type { AppDispatch, AppGetState } from 'app/store/store';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { uniq } from 'es-toolkit';
|
||||
import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd';
|
||||
import type { DndDragPreviewMultipleImageState } from 'features/dnd/DndDragPreviewMultipleImage';
|
||||
import { createMultipleImageDragPreview, setMultipleImageDragPreview } from 'features/dnd/DndDragPreviewMultipleImage';
|
||||
@@ -15,11 +16,13 @@ import { firefoxDndFix } from 'features/dnd/util';
|
||||
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import { GalleryImageHoverIcons } from 'features/gallery/components/ImageGrid/GalleryImageHoverIcons';
|
||||
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
|
||||
import { imageToCompareChanged, selectGallerySlice } from 'features/gallery/store/gallerySlice';
|
||||
import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, selectGallerySlice, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context';
|
||||
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
|
||||
import type { MouseEventHandler } from 'react';
|
||||
import type { MouseEvent, MouseEventHandler } from 'react';
|
||||
import { memo, useCallback, useEffect, useId, useMemo, useRef, useState } from 'react';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
// This class name is used to calculate the number of images that fit in the gallery
|
||||
@@ -83,6 +86,54 @@ interface Props {
|
||||
imageDTO: ImageDTO;
|
||||
}
|
||||
|
||||
const buildOnClick =
|
||||
(imageName: string, dispatch: AppDispatch, getState: AppGetState) => (e: MouseEvent<HTMLDivElement>) => {
|
||||
const { shiftKey, ctrlKey, metaKey, altKey } = e;
|
||||
const state = getState();
|
||||
const queryArgs = selectListImageNamesQueryArgs(state);
|
||||
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data?.image_names ?? [];
|
||||
|
||||
// If we don't have the image names cached, we can't perform selection operations
|
||||
// This can happen if the user clicks on an image before the names are loaded
|
||||
if (imageNames.length === 0) {
|
||||
// For basic click without modifiers, we can still set selection
|
||||
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
|
||||
dispatch(selectionChanged([imageName]));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const selection = state.gallery.selection;
|
||||
|
||||
if (altKey) {
|
||||
if (state.gallery.imageToCompare === imageName) {
|
||||
dispatch(imageToCompareChanged(null));
|
||||
} else {
|
||||
dispatch(imageToCompareChanged(imageName));
|
||||
}
|
||||
} else if (shiftKey) {
|
||||
const rangeEndImageName = imageName;
|
||||
const lastSelectedImage = selection.at(-1);
|
||||
const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage);
|
||||
const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName);
|
||||
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
|
||||
// We have a valid range!
|
||||
const start = Math.min(lastClickedIndex, currentClickedIndex);
|
||||
const end = Math.max(lastClickedIndex, currentClickedIndex);
|
||||
const imagesToSelect = imageNames.slice(start, end + 1);
|
||||
dispatch(selectionChanged(uniq(selection.concat(imagesToSelect))));
|
||||
}
|
||||
} else if (ctrlKey || metaKey) {
|
||||
if (selection.some((n) => n === imageName) && selection.length > 1) {
|
||||
dispatch(selectionChanged(uniq(selection.filter((n) => n !== imageName))));
|
||||
} else {
|
||||
dispatch(selectionChanged(uniq(selection.concat(imageName))));
|
||||
}
|
||||
} else {
|
||||
dispatch(selectionChanged([imageName]));
|
||||
}
|
||||
};
|
||||
|
||||
export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
const store = useAppStore();
|
||||
const autoLayoutContext = useAutoLayoutContext();
|
||||
@@ -192,20 +243,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
setIsHovered(false);
|
||||
}, []);
|
||||
|
||||
const onClick = useCallback<MouseEventHandler<HTMLDivElement>>(
|
||||
(e) => {
|
||||
store.dispatch(
|
||||
galleryImageClicked({
|
||||
imageName: imageDTO.image_name,
|
||||
shiftKey: e.shiftKey,
|
||||
ctrlKey: e.ctrlKey,
|
||||
metaKey: e.metaKey,
|
||||
altKey: e.altKey,
|
||||
})
|
||||
);
|
||||
},
|
||||
[imageDTO, store]
|
||||
);
|
||||
const onClick = useMemo(() => buildOnClick(imageDTO.image_name, store.dispatch, store.getState), [imageDTO, store]);
|
||||
|
||||
const onDoubleClick = useCallback<MouseEventHandler<HTMLDivElement>>(() => {
|
||||
store.dispatch(imageToCompareChanged(null));
|
||||
@@ -238,6 +276,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
ref={ref}
|
||||
src={imageDTO.thumbnail_url}
|
||||
w={imageDTO.width}
|
||||
fallback={<GalleryImagePlaceholder />}
|
||||
objectFit="contain"
|
||||
maxW="full"
|
||||
maxH="full"
|
||||
@@ -253,3 +292,5 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
|
||||
});
|
||||
|
||||
GalleryImage.displayName = 'GalleryImage';
|
||||
|
||||
export const GalleryImagePlaceholder = memo(() => <Box w="full" h="full" bg="base.850" borderRadius="base" />);
|
||||
|
||||
@@ -2,15 +2,16 @@ import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBasedImageFetching';
|
||||
import type { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
|
||||
import {
|
||||
LIMIT,
|
||||
selectGalleryImageMinimumWidth,
|
||||
selectImageToCompare,
|
||||
selectLastSelectedImage,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context';
|
||||
import { useOverlayScrollbars } from 'overlayscrollbars-react';
|
||||
import type { MutableRefObject, RefObject } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
@@ -23,22 +24,15 @@ import type {
|
||||
VirtuosoGridHandle,
|
||||
} from 'react-virtuoso';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
import { useListImagesQuery } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
import { GalleryImage } from './ImageGrid/GalleryImage';
|
||||
import { GalleryImage, GalleryImagePlaceholder } from './ImageGrid/GalleryImage';
|
||||
import { GallerySelectionCountTag } from './ImageGrid/GallerySelectionCountTag';
|
||||
import { useGalleryImageNames } from './use-gallery-image-names';
|
||||
|
||||
const log = logger('gallery');
|
||||
|
||||
// Constants
|
||||
const VIEWPORT_BUFFER = 2048;
|
||||
const SCROLL_SEEK_VELOCITY_THRESHOLD = 4096;
|
||||
const DEBOUNCE_DELAY = 500;
|
||||
const SPINNER_OPACITY = 0.3;
|
||||
|
||||
type ListImageNamesQueryArgs = ReturnType<typeof selectListImageNamesQueryArgs>;
|
||||
|
||||
type GridContext = {
|
||||
@@ -46,58 +40,41 @@ type GridContext = {
|
||||
imageNames: string[];
|
||||
};
|
||||
|
||||
// Hook to get an image DTO from cache or trigger loading
|
||||
const useImageDTOFromListQuery = (
|
||||
index: number,
|
||||
imageName: string,
|
||||
queryArgs: ListImageNamesQueryArgs
|
||||
): ImageDTO | null => {
|
||||
const { arg, options } = useMemo(() => {
|
||||
const pageOffset = Math.floor(index / LIMIT) * LIMIT;
|
||||
return {
|
||||
arg: {
|
||||
...queryArgs,
|
||||
offset: pageOffset,
|
||||
limit: LIMIT,
|
||||
} satisfies Parameters<typeof useListImagesQuery>[0],
|
||||
options: {
|
||||
selectFromResult: ({ data }) => {
|
||||
const imageDTO = data?.items?.[index - pageOffset] || null;
|
||||
if (imageDTO && imageDTO.image_name !== imageName) {
|
||||
log.warn(`Image at index ${index} does not match expected image name ${imageName}`);
|
||||
return { imageDTO: null };
|
||||
}
|
||||
return { imageDTO };
|
||||
},
|
||||
} satisfies Parameters<typeof useListImagesQuery>[1],
|
||||
};
|
||||
}, [index, queryArgs, imageName]);
|
||||
const ImageAtPosition = memo(({ imageName }: { index: number; imageName: string }) => {
|
||||
/*
|
||||
* We rely on the useRangeBasedImageFetching to fetch all image DTOs, caching them with RTK Query.
|
||||
*
|
||||
* In this component, we just want to consume that cache. Unforutnately, RTK Query does not provide a way to
|
||||
* subscribe to a query without triggering a new fetch.
|
||||
*
|
||||
* There is a hack, though:
|
||||
* - https://github.com/reduxjs/redux-toolkit/discussions/4213
|
||||
*
|
||||
* This essentially means "subscribe to the query once it has some data".
|
||||
*/
|
||||
|
||||
const { imageDTO } = useListImagesQuery(arg, options);
|
||||
// Use `currentData` instead of `data` to prevent a flash of previous image rendered at this index
|
||||
const { currentData: imageDTO, isUninitialized } = imagesApi.endpoints.getImageDTO.useQueryState(imageName);
|
||||
imagesApi.endpoints.getImageDTO.useQuerySubscription(imageName, { skip: isUninitialized });
|
||||
|
||||
return imageDTO;
|
||||
};
|
||||
|
||||
// Individual image component that gets its data from RTK Query cache
|
||||
const ImageAtPosition = memo(
|
||||
({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImageNamesQueryArgs }) => {
|
||||
const imageDTO = useImageDTOFromListQuery(index, imageName, queryArgs);
|
||||
|
||||
if (!imageDTO) {
|
||||
return <Skeleton w="full" h="full" />;
|
||||
}
|
||||
|
||||
return <GalleryImage imageDTO={imageDTO} />;
|
||||
if (!imageDTO) {
|
||||
return <GalleryImagePlaceholder />;
|
||||
}
|
||||
);
|
||||
|
||||
return <GalleryImage imageDTO={imageDTO} />;
|
||||
});
|
||||
ImageAtPosition.displayName = 'ImageAtPosition';
|
||||
|
||||
// Memoized compute key function using image names
|
||||
const computeItemKey: GridComputeItemKey<string, GridContext> = (_index, imageName, { queryArgs }) => {
|
||||
return `${JSON.stringify(queryArgs)}-${imageName}`;
|
||||
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
|
||||
return `${JSON.stringify(queryArgs)}-${imageName ?? index}`;
|
||||
};
|
||||
|
||||
// Physical DOM-based grid calculation using refs (based on working old implementation)
|
||||
/**
|
||||
* Calculate how many images fit in a row based on the current grid layout.
|
||||
*
|
||||
* TODO(psyche): We only need to do this when the gallery width changes, or when the galleryImageMinimumWidth value
|
||||
* changes. Cache this calculation.
|
||||
*/
|
||||
const getImagesPerRow = (rootEl: HTMLDivElement): number => {
|
||||
// Start from root and find virtuoso grid elements
|
||||
const gridElement = rootEl.querySelector('.virtuoso-grid-list');
|
||||
@@ -124,7 +101,14 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Use the exact calculation from the working old implementation
|
||||
/**
|
||||
* You might be tempted to just do some simple math like:
|
||||
* const imagesPerRow = Math.floor(containerRect.width / itemRect.width);
|
||||
*
|
||||
* But floating point precision can cause issues with this approach, causing it to be off by 1 in some cases.
|
||||
*
|
||||
* Instead, we use a more robust approach that iteratively calculates how many images fit in the row.
|
||||
*/
|
||||
let imagesPerRow = 0;
|
||||
let spaceUsed = 0;
|
||||
|
||||
@@ -141,7 +125,9 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => {
|
||||
return Math.max(1, imagesPerRow);
|
||||
};
|
||||
|
||||
// Check if an item at a given index is visible in the viewport
|
||||
/**
|
||||
* Scroll the item at the given index into view if it is not currently visible.
|
||||
*/
|
||||
const scrollIntoView = (
|
||||
index: number,
|
||||
rootEl: HTMLDivElement,
|
||||
@@ -202,6 +188,11 @@ const scrollIntoView = (
|
||||
return;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get the index of the image in the list of image names.
|
||||
* If the image name is not found, return 0.
|
||||
* If no image name is provided, return 0.
|
||||
*/
|
||||
const getImageIndex = (imageName: string | undefined | null, imageNames: string[]) => {
|
||||
if (!imageName || imageNames.length === 0) {
|
||||
return 0;
|
||||
@@ -210,7 +201,9 @@ const getImageIndex = (imageName: string | undefined | null, imageNames: string[
|
||||
return index >= 0 ? index : 0;
|
||||
};
|
||||
|
||||
// Hook for keyboard navigation using physical DOM measurements
|
||||
/**
|
||||
* Handles keyboard navigation for the gallery.
|
||||
*/
|
||||
const useKeyboardNavigation = (
|
||||
imageNames: string[],
|
||||
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
|
||||
@@ -249,11 +242,12 @@ const useKeyboardNavigation = (
|
||||
|
||||
event.preventDefault();
|
||||
|
||||
const state = getState();
|
||||
const imageName = event.altKey
|
||||
? // When the user holds alt, we are changing the image to compare - if no image to compare is currently selected,
|
||||
// we start from the last selected image
|
||||
(selectImageToCompare(getState()) ?? selectLastSelectedImage(getState()))
|
||||
: selectLastSelectedImage(getState());
|
||||
(selectImageToCompare(state) ?? selectLastSelectedImage(state))
|
||||
: selectLastSelectedImage(state);
|
||||
|
||||
const currentIndex = getImageIndex(imageName, imageNames);
|
||||
|
||||
@@ -373,6 +367,11 @@ const useKeyboardNavigation = (
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Keeps the last selected image in view when the gallery is scrolled.
|
||||
* This is useful for keyboard navigation and ensuring the user can see their selection.
|
||||
* It only tracks the last selected image, not the image to compare.
|
||||
*/
|
||||
const useKeepSelectedImageInView = (
|
||||
imageNames: string[],
|
||||
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
|
||||
@@ -397,6 +396,9 @@ const useKeepSelectedImageInView = (
|
||||
}, [imageName, imageNames, rangeRef, rootRef, virtuosoRef]);
|
||||
};
|
||||
|
||||
/**
|
||||
* Handles the initialization of the overlay scrollbars for the gallery, returning the ref to the scroller element.
|
||||
*/
|
||||
const useScrollableGallery = (rootRef: RefObject<HTMLDivElement>) => {
|
||||
const [scroller, scrollerRef] = useState<HTMLElement | null>(null);
|
||||
const [initialize, osInstance] = useOverlayScrollbars({
|
||||
@@ -431,43 +433,49 @@ const useScrollableGallery = (rootRef: RefObject<HTMLDivElement>) => {
|
||||
return scrollerRef;
|
||||
};
|
||||
|
||||
// Main gallery component
|
||||
export const NewGallery = memo(() => {
|
||||
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
|
||||
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
|
||||
const rootRef = useRef<HTMLDivElement>(null);
|
||||
const { isActiveTab } = useAutoLayoutContext();
|
||||
|
||||
// Get the ordered list of image names - this is our primary data source for virtualization
|
||||
const { queryArgs, imageNames, isLoading } = useGalleryImageNames();
|
||||
|
||||
// Use range-based fetching for bulk loading image DTOs into cache based on the visible range
|
||||
const { onRangeChanged } = useRangeBasedImageFetching({
|
||||
imageNames,
|
||||
enabled: !isLoading && isActiveTab,
|
||||
});
|
||||
|
||||
useKeepSelectedImageInView(imageNames, virtuosoRef, rootRef, rangeRef);
|
||||
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
|
||||
const scrollerRef = useScrollableGallery(rootRef);
|
||||
|
||||
// We have to keep track of the visible range for keep-selected-image-in-view functionality
|
||||
const handleRangeChanged = useCallback((range: ListRange) => {
|
||||
rangeRef.current = range;
|
||||
}, []);
|
||||
|
||||
const context = useMemo(
|
||||
() =>
|
||||
({
|
||||
imageNames,
|
||||
queryArgs,
|
||||
}) satisfies GridContext,
|
||||
[imageNames, queryArgs]
|
||||
/*
|
||||
* We have to keep track of the visible range for keep-selected-image-in-view functionality and push the range to
|
||||
* the range-based image fetching hook.
|
||||
*/
|
||||
const handleRangeChanged = useCallback(
|
||||
(range: ListRange) => {
|
||||
rangeRef.current = range;
|
||||
onRangeChanged(range);
|
||||
},
|
||||
[onRangeChanged]
|
||||
);
|
||||
|
||||
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]);
|
||||
|
||||
// Item content function
|
||||
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName, ctx) => {
|
||||
return <ImageAtPosition index={index} imageName={imageName} queryArgs={ctx.queryArgs} />;
|
||||
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName) => {
|
||||
return <ImageAtPosition index={index} imageName={imageName} />;
|
||||
}, []);
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<Flex height="100%" alignItems="center" justifyContent="center">
|
||||
<Spinner size="lg" opacity={SPINNER_OPACITY} />
|
||||
<Text ml={4}>Loading gallery...</Text>
|
||||
<Flex height="100%" alignItems="center" justifyContent="center" gap={4}>
|
||||
<Spinner size="lg" opacity={0.3} />
|
||||
<Text color="base.300">Loading gallery...</Text>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
@@ -481,12 +489,13 @@ export const NewGallery = memo(() => {
|
||||
}
|
||||
|
||||
return (
|
||||
// This wrapper component is necessary to initialize the overlay scrollbars!
|
||||
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
|
||||
<VirtuosoGrid<string, GridContext>
|
||||
ref={virtuosoRef}
|
||||
context={context}
|
||||
data={imageNames}
|
||||
increaseViewportBy={VIEWPORT_BUFFER}
|
||||
increaseViewportBy={2048}
|
||||
itemContent={itemContent}
|
||||
computeItemKey={computeItemKey}
|
||||
components={components}
|
||||
@@ -503,7 +512,7 @@ export const NewGallery = memo(() => {
|
||||
NewGallery.displayName = 'NewGallery';
|
||||
|
||||
const scrollSeekConfiguration: ScrollSeekConfiguration = {
|
||||
enter: (velocity) => velocity > SCROLL_SEEK_VELOCITY_THRESHOLD,
|
||||
enter: (velocity) => velocity > 4096,
|
||||
exit: (velocity) => velocity === 0,
|
||||
};
|
||||
|
||||
@@ -518,7 +527,7 @@ const selectGridTemplateColumns = createSelector(
|
||||
// Grid components
|
||||
const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context: _, ...rest }, ref) => {
|
||||
const _gridTemplateColumns = useAppSelector(selectGridTemplateColumns);
|
||||
const [gridTemplateColumns] = useDebounce(_gridTemplateColumns, DEBOUNCE_DELAY);
|
||||
const [gridTemplateColumns] = useDebounce(_gridTemplateColumns, 300);
|
||||
|
||||
return <Grid ref={ref} gridTemplateColumns={gridTemplateColumns} gap={1} {...rest} />;
|
||||
});
|
||||
|
||||
@@ -5,8 +5,8 @@ import { useGetImageNamesQuery } from 'services/api/endpoints/images';
|
||||
import { useDebounce } from 'use-debounce';
|
||||
|
||||
const getImageNamesQueryOptions = {
|
||||
selectFromResult: ({ data, isLoading, isFetching }) => ({
|
||||
imageNames: data ?? EMPTY_ARRAY,
|
||||
selectFromResult: ({ currentData, isLoading, isFetching }) => ({
|
||||
imageNames: currentData?.image_names ?? EMPTY_ARRAY,
|
||||
isLoading,
|
||||
isFetching,
|
||||
}),
|
||||
@@ -14,7 +14,7 @@ const getImageNamesQueryOptions = {
|
||||
|
||||
export const useGalleryImageNames = () => {
|
||||
const _queryArgs = useAppSelector(selectListImageNamesQueryArgs);
|
||||
const [queryArgs] = useDebounce(_queryArgs, 500);
|
||||
const [queryArgs] = useDebounce(_queryArgs, 300);
|
||||
const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(queryArgs, getImageNamesQueryOptions);
|
||||
return { imageNames, isLoading, isFetching, queryArgs };
|
||||
};
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { useCallback, useEffect, useState } from 'react';
|
||||
import type { ListRange } from 'react-virtuoso';
|
||||
import { imagesApi, useGetImageDTOsByNamesMutation } from 'services/api/endpoints/images';
|
||||
import { useThrottledCallback } from 'use-debounce';
|
||||
|
||||
interface UseRangeBasedImageFetchingArgs {
|
||||
imageNames: string[];
|
||||
enabled: boolean;
|
||||
}
|
||||
|
||||
interface UseRangeBasedImageFetchingReturn {
|
||||
onRangeChanged: (range: ListRange) => void;
|
||||
}
|
||||
|
||||
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], range: ListRange): string[] => {
|
||||
if (range.startIndex === range.endIndex) {
|
||||
// If the start and end indices are the same, no range to fetch
|
||||
return [];
|
||||
}
|
||||
|
||||
if (imageNames.length === 0) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const start = Math.max(0, range.startIndex);
|
||||
const end = Math.min(imageNames.length - 1, range.endIndex);
|
||||
|
||||
if (cachedImageNames.length === 0) {
|
||||
return imageNames.slice(start, end + 1);
|
||||
}
|
||||
|
||||
const uncachedNames: string[] = [];
|
||||
|
||||
for (let i = start; i <= end; i++) {
|
||||
const imageName = imageNames[i]!;
|
||||
if (!cachedImageNames.includes(imageName)) {
|
||||
uncachedNames.push(imageName);
|
||||
}
|
||||
}
|
||||
|
||||
return uncachedNames;
|
||||
};
|
||||
|
||||
/**
|
||||
* Hook for bulk fetching image DTOs based on the visible range from virtuoso.
|
||||
* Individual image components should use `useGetImageDTOQuery(imageName)` to get their specific DTO.
|
||||
* This hook ensures DTOs are bulk fetched and cached efficiently.
|
||||
*/
|
||||
export const useRangeBasedImageFetching = ({
|
||||
imageNames,
|
||||
enabled,
|
||||
}: UseRangeBasedImageFetchingArgs): UseRangeBasedImageFetchingReturn => {
|
||||
const store = useAppStore();
|
||||
const [visibleRange, setVisibleRange] = useState<ListRange>({ startIndex: 0, endIndex: 0 });
|
||||
const [getImageDTOsByNames] = useGetImageDTOsByNamesMutation();
|
||||
|
||||
const fetchImages = useCallback(
|
||||
(visibleRange: ListRange) => {
|
||||
const cachedImageNames = imagesApi.util.selectCachedArgsForQuery(store.getState(), 'getImageDTO');
|
||||
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, visibleRange);
|
||||
if (uncachedNames.length === 0) {
|
||||
return;
|
||||
}
|
||||
getImageDTOsByNames({ image_names: uncachedNames });
|
||||
},
|
||||
[getImageDTOsByNames, imageNames, store]
|
||||
);
|
||||
|
||||
const throttledFetchImages = useThrottledCallback(fetchImages, 100);
|
||||
|
||||
useEffect(() => {
|
||||
if (!enabled) {
|
||||
return;
|
||||
}
|
||||
throttledFetchImages(visibleRange);
|
||||
}, [enabled, throttledFetchImages, imageNames, visibleRange]);
|
||||
|
||||
const onRangeChanged = useCallback((range: ListRange) => {
|
||||
setVisibleRange(range);
|
||||
}, []);
|
||||
|
||||
return {
|
||||
onRangeChanged,
|
||||
};
|
||||
};
|
||||
@@ -1,5 +1,9 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import type { DockviewApi, GridviewApi } from 'dockview';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import type { TabName } from 'features/ui/store/uiTypes';
|
||||
import type { WritableAtom } from 'nanostores';
|
||||
import { atom } from 'nanostores';
|
||||
import type { PropsWithChildren, RefObject } from 'react';
|
||||
@@ -8,6 +12,7 @@ import { createContext, memo, useCallback, useContext, useMemo, useState } from
|
||||
import { LEFT_PANEL_ID, LEFT_PANEL_MIN_SIZE_PX, RIGHT_PANEL_ID, RIGHT_PANEL_MIN_SIZE_PX } from './shared';
|
||||
|
||||
type AutoLayoutContextValue = {
|
||||
isActiveTab: boolean;
|
||||
toggleLeftPanel: () => void;
|
||||
toggleRightPanel: () => void;
|
||||
toggleBothPanels: () => void;
|
||||
@@ -57,9 +62,15 @@ const activatePanel = (api: GridviewApi | DockviewApi, panelId: string) => {
|
||||
};
|
||||
|
||||
export const AutoLayoutProvider = (
|
||||
props: PropsWithChildren<{ $rootApi: WritableAtom<GridviewApi | null>; rootRef: RefObject<HTMLDivElement> }>
|
||||
props: PropsWithChildren<{
|
||||
$rootApi: WritableAtom<GridviewApi | null>;
|
||||
rootRef: RefObject<HTMLDivElement>;
|
||||
tab: TabName;
|
||||
}>
|
||||
) => {
|
||||
const { $rootApi, rootRef, children } = props;
|
||||
const { $rootApi, rootRef, tab, children } = props;
|
||||
const selectIsActiveTab = useMemo(() => createSelector(selectActiveTab, (activeTab) => activeTab === tab), [tab]);
|
||||
const isActiveTab = useAppSelector(selectIsActiveTab);
|
||||
const $leftApi = useState(() => atom<GridviewApi | null>(null))[0];
|
||||
const $centerApi = useState(() => atom<DockviewApi | null>(null))[0];
|
||||
const $rightApi = useState(() => atom<GridviewApi | null>(null))[0];
|
||||
@@ -126,6 +137,7 @@ export const AutoLayoutProvider = (
|
||||
|
||||
const value = useMemo<AutoLayoutContextValue>(
|
||||
() => ({
|
||||
isActiveTab,
|
||||
toggleLeftPanel,
|
||||
toggleRightPanel,
|
||||
toggleBothPanels,
|
||||
@@ -138,6 +150,7 @@ export const AutoLayoutProvider = (
|
||||
_$rightPanelApi: $rightApi,
|
||||
}),
|
||||
[
|
||||
isActiveTab,
|
||||
$centerApi,
|
||||
$leftApi,
|
||||
$rightApi,
|
||||
|
||||
@@ -259,7 +259,7 @@ export const CanvasTabAutoLayout = memo(() => {
|
||||
useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef);
|
||||
|
||||
return (
|
||||
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef}>
|
||||
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef} tab="canvas">
|
||||
<GridviewReact
|
||||
ref={rootRef}
|
||||
className="dockview-theme-invoke"
|
||||
|
||||
@@ -234,7 +234,7 @@ export const GenerateTabAutoLayout = memo(() => {
|
||||
useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef);
|
||||
|
||||
return (
|
||||
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef}>
|
||||
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef} tab="generate">
|
||||
<GridviewReact
|
||||
ref={rootRef}
|
||||
className="dockview-theme-invoke"
|
||||
|
||||
@@ -229,7 +229,7 @@ export const UpscalingTabAutoLayout = memo(() => {
|
||||
useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef);
|
||||
|
||||
return (
|
||||
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef}>
|
||||
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef} tab="upscaling">
|
||||
<GridviewReact
|
||||
ref={rootRef}
|
||||
className="dockview-theme-invoke"
|
||||
|
||||
@@ -247,7 +247,7 @@ export const WorkflowsTabAutoLayout = memo(() => {
|
||||
useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef);
|
||||
|
||||
return (
|
||||
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef}>
|
||||
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef} tab="workflows">
|
||||
<GridviewReact
|
||||
ref={rootRef}
|
||||
className="dockview-theme-invoke"
|
||||
|
||||
@@ -7,6 +7,7 @@ import type {
|
||||
GraphAndWorkflowResponse,
|
||||
ImageCategory,
|
||||
ImageDTO,
|
||||
ImageNamesResult,
|
||||
ImageUploadEntryRequest,
|
||||
ImageUploadEntryResponse,
|
||||
ListImagesArgs,
|
||||
@@ -431,7 +432,7 @@ export const imagesApi = api.injectEndpoints({
|
||||
* Get ordered list of image names for selection operations
|
||||
*/
|
||||
getImageNames: build.query<
|
||||
string[],
|
||||
ImageNamesResult,
|
||||
{
|
||||
categories?: ImageCategory[] | null;
|
||||
is_intermediate?: boolean | null;
|
||||
@@ -450,6 +451,38 @@ export const imagesApi = api.injectEndpoints({
|
||||
{ type: 'ImageNameList', id: stableHash(queryArgs) },
|
||||
],
|
||||
}),
|
||||
/**
|
||||
* Get image DTOs for the specified image names. Maintains order of input names.
|
||||
*/
|
||||
getImageDTOsByNames: build.mutation<
|
||||
paths['/api/v1/images/images_by_names']['post']['responses']['200']['content']['application/json'],
|
||||
paths['/api/v1/images/images_by_names']['post']['requestBody']['content']['application/json']
|
||||
>({
|
||||
query: (body) => ({
|
||||
url: buildImagesUrl('images_by_names'),
|
||||
method: 'POST',
|
||||
body,
|
||||
}),
|
||||
// Don't provide cache tags - we'll manually upsert into individual getImageDTO caches
|
||||
async onQueryStarted(_, { dispatch, queryFulfilled }) {
|
||||
try {
|
||||
const { data: imageDTOs } = await queryFulfilled;
|
||||
|
||||
// Upsert each DTO into the individual image cache
|
||||
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
|
||||
for (const imageDTO of imageDTOs) {
|
||||
updates.push({
|
||||
endpointName: 'getImageDTO',
|
||||
arg: imageDTO.image_name,
|
||||
value: imageDTO,
|
||||
});
|
||||
}
|
||||
dispatch(imagesApi.util.upsertQueryEntries(updates));
|
||||
} catch {
|
||||
// Handle error if needed
|
||||
}
|
||||
},
|
||||
}),
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -472,6 +505,7 @@ export const {
|
||||
useUnstarImagesMutation,
|
||||
useBulkDownloadImagesMutation,
|
||||
useGetImageNamesQuery,
|
||||
useGetImageDTOsByNamesMutation,
|
||||
} = imagesApi;
|
||||
|
||||
/**
|
||||
|
||||
@@ -761,7 +761,7 @@ export type paths = {
|
||||
};
|
||||
/**
|
||||
* Get Image Names
|
||||
* @description Gets ordered list of all image names (starred first, then unstarred)
|
||||
* @description Gets ordered list of image names with metadata for optimistic updates
|
||||
*/
|
||||
get: operations["get_image_names"];
|
||||
put?: never;
|
||||
@@ -772,6 +772,26 @@ export type paths = {
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/v1/images/images_by_names": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
get?: never;
|
||||
put?: never;
|
||||
/**
|
||||
* Get Images By Names
|
||||
* @description Gets image DTOs for the specified image names. Maintains order of input names.
|
||||
*/
|
||||
post: operations["get_images_by_names"];
|
||||
delete?: never;
|
||||
options?: never;
|
||||
head?: never;
|
||||
patch?: never;
|
||||
trace?: never;
|
||||
};
|
||||
"/api/v1/boards/": {
|
||||
parameters: {
|
||||
query?: never;
|
||||
@@ -2648,6 +2668,14 @@ export type components = {
|
||||
/** @description The validation run data to use for this batch. This is only used if this is a validation run. */
|
||||
validation_run_data?: components["schemas"]["ValidationRunData"] | null;
|
||||
};
|
||||
/** Body_get_images_by_names */
|
||||
Body_get_images_by_names: {
|
||||
/**
|
||||
* Image Names
|
||||
* @description Object containing list of image names to fetch DTOs for
|
||||
*/
|
||||
image_names: string[];
|
||||
};
|
||||
/** Body_import_style_presets */
|
||||
Body_import_style_presets: {
|
||||
/**
|
||||
@@ -10479,6 +10507,27 @@ export type components = {
|
||||
*/
|
||||
type: "img_nsfw";
|
||||
};
|
||||
/**
|
||||
* ImageNamesResult
|
||||
* @description Response containing ordered image names with metadata for optimistic updates.
|
||||
*/
|
||||
ImageNamesResult: {
|
||||
/**
|
||||
* Image Names
|
||||
* @description Ordered list of image names
|
||||
*/
|
||||
image_names: string[];
|
||||
/**
|
||||
* Starred Count
|
||||
* @description Number of starred images (when starred_first=True)
|
||||
*/
|
||||
starred_count: number;
|
||||
/**
|
||||
* Total Count
|
||||
* @description Total number of images matching the query
|
||||
*/
|
||||
total_count: number;
|
||||
};
|
||||
/**
|
||||
* Add Image Noise
|
||||
* @description Add noise to an image
|
||||
@@ -23725,7 +23774,40 @@ export interface operations {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": string[];
|
||||
"application/json": components["schemas"]["ImageNamesResult"];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
422: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["HTTPValidationError"];
|
||||
};
|
||||
};
|
||||
};
|
||||
};
|
||||
get_images_by_names: {
|
||||
parameters: {
|
||||
query?: never;
|
||||
header?: never;
|
||||
path?: never;
|
||||
cookie?: never;
|
||||
};
|
||||
requestBody: {
|
||||
content: {
|
||||
"application/json": components["schemas"]["Body_get_images_by_names"];
|
||||
};
|
||||
};
|
||||
responses: {
|
||||
/** @description Successful Response */
|
||||
200: {
|
||||
headers: {
|
||||
[name: string]: unknown;
|
||||
};
|
||||
content: {
|
||||
"application/json": components["schemas"]["ImageDTO"][];
|
||||
};
|
||||
};
|
||||
/** @description Validation Error */
|
||||
|
||||
@@ -7,6 +7,8 @@ export type S = components['schemas'];
|
||||
export type ListImagesArgs = NonNullable<paths['/api/v1/images/']['get']['parameters']['query']>;
|
||||
export type ListImagesResponse = paths['/api/v1/images/']['get']['responses']['200']['content']['application/json'];
|
||||
|
||||
export type ImageNamesResult = S['ImageNamesResult'];
|
||||
|
||||
export type ListBoardsArgs = NonNullable<paths['/api/v1/boards/']['get']['parameters']['query']>;
|
||||
|
||||
export type DeleteBoardResult =
|
||||
|
||||
@@ -0,0 +1,90 @@
|
||||
import type { ImageDTO, ImageNamesResult } from 'services/api/types';
|
||||
|
||||
/**
|
||||
* Calculates the optimal insertion position for a new image in the names list.
|
||||
* For starred_first=true: starred images go to position 0, unstarred go after all starred images
|
||||
* For starred_first=false: all new images go to position 0 (newest first)
|
||||
*/
|
||||
export function calculateImageInsertionPosition(
|
||||
imageDTO: ImageDTO,
|
||||
starredFirst: boolean,
|
||||
starredCount: number
|
||||
): number {
|
||||
if (!starredFirst) {
|
||||
// When starred_first is false, always insert at the beginning (newest first)
|
||||
return 0;
|
||||
}
|
||||
|
||||
// When starred_first is true
|
||||
if (imageDTO.starred) {
|
||||
// Starred images go at the very beginning
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Unstarred images go after all starred images
|
||||
return starredCount;
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimistically inserts a new image into the ImageNamesResult at the correct position
|
||||
*/
|
||||
export function insertImageIntoNamesResult(
|
||||
currentResult: ImageNamesResult,
|
||||
imageDTO: ImageDTO,
|
||||
starredFirst: boolean
|
||||
): ImageNamesResult {
|
||||
// Don't insert if the image is already in the list
|
||||
if (currentResult.image_names.includes(imageDTO.image_name)) {
|
||||
return currentResult;
|
||||
}
|
||||
|
||||
const insertPosition = calculateImageInsertionPosition(imageDTO, starredFirst, currentResult.starred_count);
|
||||
|
||||
const newImageNames = [...currentResult.image_names];
|
||||
newImageNames.splice(insertPosition, 0, imageDTO.image_name);
|
||||
|
||||
return {
|
||||
image_names: newImageNames,
|
||||
starred_count: starredFirst && imageDTO.starred ? currentResult.starred_count + 1 : currentResult.starred_count,
|
||||
total_count: currentResult.total_count + 1,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimistically removes an image from the ImageNamesResult
|
||||
*/
|
||||
export function removeImageFromNamesResult(
|
||||
currentResult: ImageNamesResult,
|
||||
imageNameToRemove: string,
|
||||
wasStarred: boolean,
|
||||
starredFirst: boolean
|
||||
): ImageNamesResult {
|
||||
const newImageNames = currentResult.image_names.filter((name) => name !== imageNameToRemove);
|
||||
|
||||
return {
|
||||
image_names: newImageNames,
|
||||
starred_count: starredFirst && wasStarred ? currentResult.starred_count - 1 : currentResult.starred_count,
|
||||
total_count: currentResult.total_count - 1,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimistically updates an image's position in the result when its starred status changes
|
||||
*/
|
||||
export function updateImagePositionInNamesResult(
|
||||
currentResult: ImageNamesResult,
|
||||
updatedImageDTO: ImageDTO,
|
||||
previouslyStarred: boolean,
|
||||
starredFirst: boolean
|
||||
): ImageNamesResult {
|
||||
// First remove the image from its current position
|
||||
const withoutImage = removeImageFromNamesResult(
|
||||
currentResult,
|
||||
updatedImageDTO.image_name,
|
||||
previouslyStarred,
|
||||
starredFirst
|
||||
);
|
||||
|
||||
// Then insert it at the new correct position
|
||||
return insertImageIntoNamesResult(withoutImage, updatedImageDTO, starredFirst);
|
||||
}
|
||||
@@ -1,24 +1,22 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
|
||||
import type { AppDispatch, AppGetState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import {
|
||||
selectAutoSwitch,
|
||||
selectGalleryView,
|
||||
selectListImagesBaseQueryArgs,
|
||||
selectListImageNamesQueryArgs,
|
||||
selectSelectedBoardId,
|
||||
} from 'features/gallery/store/gallerySelectors';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import type { ApiTagDescription } from 'services/api';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { getCategories } from 'services/api/util';
|
||||
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
|
||||
import { $lastProgressEvent } from 'services/events/stores';
|
||||
import stableHash from 'stable-hash';
|
||||
import type { Param0 } from 'tsafe';
|
||||
import { objectEntries } from 'tsafe';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
@@ -42,9 +40,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
|
||||
// For efficiency's sake, we want to minimize the number of dispatches and invalidations we do.
|
||||
// We'll keep track of each change we need to make and do them all at once.
|
||||
const boardTotalAdditions: Record<string, number> = {};
|
||||
const boardTagIdsToInvalidate: Set<string> = new Set();
|
||||
const imageListTagIdsToInvalidate: Set<string> = new Set();
|
||||
const listImagesArg = selectListImagesBaseQueryArgs(getState());
|
||||
const listImageNamesArg = selectListImageNamesQueryArgs(getState());
|
||||
|
||||
for (const imageDTO of imageDTOs) {
|
||||
if (imageDTO.is_intermediate) {
|
||||
@@ -54,17 +50,6 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
|
||||
const board_id = imageDTO.board_id ?? 'none';
|
||||
// update the total images for the board
|
||||
boardTotalAdditions[board_id] = (boardTotalAdditions[board_id] || 0) + 1;
|
||||
// invalidate the board tag
|
||||
boardTagIdsToInvalidate.add(board_id);
|
||||
// invalidate the image list tag
|
||||
imageListTagIdsToInvalidate.add(
|
||||
stableHash({
|
||||
...listImagesArg,
|
||||
categories: getCategories(imageDTO),
|
||||
board_id,
|
||||
offset: 0,
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Update all the board image totals at once
|
||||
@@ -85,16 +70,40 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
|
||||
}
|
||||
dispatch(boardsApi.util.upsertQueryEntries(entries));
|
||||
|
||||
// Invalidate all tags at once
|
||||
const boardTags: ApiTagDescription[] = Array.from(boardTagIdsToInvalidate).map((boardId) => ({
|
||||
type: 'Board' as const,
|
||||
id: boardId,
|
||||
}));
|
||||
const imageListTags: ApiTagDescription[] = Array.from(imageListTagIdsToInvalidate).map((imageListId) => ({
|
||||
type: 'ImageList' as const,
|
||||
id: imageListId,
|
||||
}));
|
||||
dispatch(imagesApi.util.invalidateTags(['ImageNameList', ...boardTags, ...imageListTags]));
|
||||
// Optimistically update image names lists - DTOs are already cached by getResultImageDTOs
|
||||
const state = getState();
|
||||
|
||||
for (const imageDTO of imageDTOs) {
|
||||
// Construct the expected query args for this image's getImageNames query
|
||||
// Use the current gallery query args as base, but override board_id and categories for this specific image
|
||||
const expectedQueryArgs = {
|
||||
...listImageNamesArg,
|
||||
categories: getCategories(imageDTO),
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
};
|
||||
|
||||
// Check if we have cached image names for this query
|
||||
const cachedNamesResult = imagesApi.endpoints.getImageNames.select(expectedQueryArgs)(state);
|
||||
|
||||
if (cachedNamesResult.data) {
|
||||
// We have cached names - optimistically insert the new image
|
||||
dispatch(
|
||||
imagesApi.util.updateQueryData('getImageNames', expectedQueryArgs, (draft) => {
|
||||
// Use the utility function to insert at the correct position
|
||||
const updatedResult = insertImageIntoNamesResult(draft, imageDTO, expectedQueryArgs.starred_first ?? true);
|
||||
|
||||
// Replace the draft contents
|
||||
draft.image_names = updatedResult.image_names;
|
||||
draft.starred_count = updatedResult.starred_count;
|
||||
draft.total_count = updatedResult.total_count;
|
||||
})
|
||||
);
|
||||
}
|
||||
// If no cached data, we don't need to do anything - there's no list to update
|
||||
}
|
||||
|
||||
// No need to invalidate tags since we're doing optimistic updates
|
||||
// Board totals are already updated above via upsertQueryEntries
|
||||
|
||||
const autoSwitch = selectAutoSwitch(getState());
|
||||
|
||||
@@ -112,63 +121,27 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
|
||||
const { image_name } = lastImageDTO;
|
||||
const board_id = lastImageDTO.board_id ?? 'none';
|
||||
|
||||
/**
|
||||
* Auto-switch needs a bit of care to avoid race conditions - we need to invalidate the appropriate image list
|
||||
* query cache, and only after it has loaded, select the new image.
|
||||
*/
|
||||
const queryArgs = {
|
||||
...listImagesArg,
|
||||
categories: getCategories(lastImageDTO),
|
||||
board_id,
|
||||
offset: 0,
|
||||
};
|
||||
// With optimistic updates, we can immediately switch to the new image
|
||||
const selectedBoardId = selectSelectedBoardId(getState());
|
||||
|
||||
dispatch(
|
||||
addAppListener({
|
||||
predicate: (action) => {
|
||||
if (!imagesApi.endpoints.listImages.matchFulfilled(action)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (stableHash(action.meta.arg.originalArgs) !== stableHash(queryArgs)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
},
|
||||
effect: (_action, { getState, dispatch, unsubscribe }) => {
|
||||
// This is a one-time listener - we always unsubscribe after the first match
|
||||
unsubscribe();
|
||||
|
||||
// Auto-switch may have been disabled while we were waiting for the query to resolve - bail if so
|
||||
const autoSwitch = selectAutoSwitch(getState());
|
||||
if (!autoSwitch) {
|
||||
return;
|
||||
}
|
||||
|
||||
const selectedBoardId = selectSelectedBoardId(getState());
|
||||
|
||||
// If the image is from a different board, switch to that board & select the image - otherwise just select the
|
||||
// image. This implicitly changes the view to 'images' if it was not already.
|
||||
if (board_id !== selectedBoardId) {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
boardId: board_id,
|
||||
selectedImageName: image_name,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
// Ensure we are on the 'images' gallery view - that's where this image will be displayed
|
||||
const galleryView = selectGalleryView(getState());
|
||||
if (galleryView !== 'images') {
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
// Else just select the image, no need to switch boards
|
||||
dispatch(imageSelected(lastImageDTO.image_name));
|
||||
}
|
||||
},
|
||||
})
|
||||
);
|
||||
// If the image is from a different board, switch to that board & select the image - otherwise just select the
|
||||
// image. This implicitly changes the view to 'images' if it was not already.
|
||||
if (board_id !== selectedBoardId) {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
boardId: board_id,
|
||||
selectedImageName: image_name,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
// Ensure we are on the 'images' gallery view - that's where this image will be displayed
|
||||
const galleryView = selectGalleryView(getState());
|
||||
if (galleryView !== 'images') {
|
||||
dispatch(galleryViewChanged('images'));
|
||||
}
|
||||
// Select the image immediately since we've optimistically updated the cache
|
||||
dispatch(imageSelected(lastImageDTO.image_name));
|
||||
}
|
||||
};
|
||||
|
||||
const getResultImageDTOs = async (data: S['InvocationCompleteEvent']): Promise<ImageDTO[]> => {
|
||||
|
||||
Reference in New Issue
Block a user