refactor: optimistic gallery updates

This commit is contained in:
psychedelicious
2025-06-26 16:44:51 +10:00
parent 504daa0ae5
commit ab5cb2c264
22 changed files with 605 additions and 303 deletions

View File

@@ -14,6 +14,7 @@ from invokeai.app.api.extract_metadata_from_image import extract_metadata_from_i
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecordChanges,
ResourceOrigin,
)
@@ -576,11 +577,11 @@ async def get_image_names(
order_dir: SQLiteDirection = Query(default=SQLiteDirection.Descending, description="The order of sort"),
starred_first: bool = Query(default=True, description="Whether to sort by starred images first"),
search_term: Optional[str] = Query(default=None, description="The term to search for"),
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)"""
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates"""
try:
image_names = ApiDependencies.invoker.services.images.get_image_names(
result = ApiDependencies.invoker.services.images.get_image_names(
starred_first=starred_first,
order_dir=order_dir,
image_origin=image_origin,
@@ -589,6 +590,34 @@ async def get_image_names(
board_id=board_id,
search_term=search_term,
)
return image_names
return result
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image names")
@images_router.post(
"/images_by_names",
operation_id="get_images_by_names",
responses={200: {"model": list[ImageDTO]}},
)
async def get_images_by_names(
image_names: list[str] = Body(embed=True, description="Object containing list of image names to fetch DTOs for"),
) -> list[ImageDTO]:
"""Gets image DTOs for the specified image names. Maintains order of input names."""
try:
image_service = ApiDependencies.invoker.services.images
# Fetch DTOs preserving the order of requested names
image_dtos: list[ImageDTO] = []
for name in image_names:
try:
dto = image_service.get_dto(name)
image_dtos.append(dto)
except Exception:
# Skip missing images - they may have been deleted between name fetch and DTO fetch
continue
return image_dtos
except Exception:
raise HTTPException(status_code=500, detail="Failed to get image DTOs")

View File

@@ -5,6 +5,7 @@ from typing import Optional
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -108,6 +109,6 @@ class ImageRecordStorageBase(ABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
"""Gets ordered list of all image names (starred first, then unstarred)."""
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

View File

@@ -212,3 +212,10 @@ def deserialize_image_record(image_dict: dict) -> ImageRecord:
class ImageCollectionCounts(BaseModel):
starred_count: int = Field(description="The number of starred images in the collection.")
unstarred_count: int = Field(description="The number of unstarred images in the collection.")
class ImageNamesResult(BaseModel):
"""Response containing ordered image names with metadata for optimistic updates."""
image_names: list[str] = Field(description="Ordered list of image names")
starred_count: int = Field(description="Number of starred images (when starred_first=True)")
total_count: int = Field(description="Total number of images matching the query")

View File

@@ -7,6 +7,7 @@ from invokeai.app.services.image_records.image_records_base import ImageRecordSt
from invokeai.app.services.image_records.image_records_common import (
IMAGE_DTO_COLS,
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -396,17 +397,10 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
) -> ImageNamesResult:
cursor = self._conn.cursor()
# Base query to get image names in order (starred first, then unstarred)
query = """--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1
"""
# Build query conditions (reused for both starred count and image names queries)
query_conditions = ""
query_params: list[Union[int, str, bool]] = []
@@ -451,22 +445,42 @@ class SqliteImageRecordStorage(ImageRecordStorageBase):
query_params.append(f"%{search_term.lower()}%")
query_params.append(f"%{search_term.lower()}%")
# Get starred count if starred_first is enabled
starred_count = 0
if starred_first:
query += (
query_conditions
+ f"""--sql
starred_count_query = f"""--sql
SELECT COUNT(*)
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE images.starred = TRUE AND (1=1{query_conditions})
"""
cursor.execute(starred_count_query, query_params)
starred_count = cast(int, cursor.fetchone()[0])
# Get all image names with proper ordering
if starred_first:
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.starred DESC, images.created_at {order_dir.value}
"""
)
else:
query += (
query_conditions
+ f"""--sql
names_query = f"""--sql
SELECT images.image_name
FROM images
LEFT JOIN board_images ON board_images.image_name = images.image_name
WHERE 1=1{query_conditions}
ORDER BY images.created_at {order_dir.value}
"""
)
cursor.execute(query, query_params)
cursor.execute(names_query, query_params)
result = cast(list[sqlite3.Row], cursor.fetchall())
image_names = [row[0] for row in result]
return [row[0] for row in result]
return ImageNamesResult(
image_names=image_names,
starred_count=starred_count,
total_count=len(image_names)
)

View File

@@ -6,6 +6,7 @@ from PIL.Image import Image as PILImageType
from invokeai.app.invocations.fields import MetadataField
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ResourceOrigin,
@@ -158,6 +159,6 @@ class ImageServiceABC(ABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
"""Gets ordered list of all image names."""
) -> ImageNamesResult:
"""Gets ordered list of image names with metadata for optimistic updates."""
pass

View File

@@ -10,6 +10,7 @@ from invokeai.app.services.image_files.image_files_common import (
)
from invokeai.app.services.image_records.image_records_common import (
ImageCategory,
ImageNamesResult,
ImageRecord,
ImageRecordChanges,
ImageRecordDeleteException,
@@ -319,7 +320,7 @@ class ImageService(ImageServiceABC):
is_intermediate: Optional[bool] = None,
board_id: Optional[str] = None,
search_term: Optional[str] = None,
) -> list[str]:
) -> ImageNamesResult:
try:
return self.__invoker.services.image_records.get_image_names(
starred_first=starred_first,

View File

@@ -10,7 +10,6 @@ import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddlew
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addEnsureImageIsSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/ensureImageIsSelectedListener';
import { addGalleryImageClickedListener } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
@@ -44,9 +43,6 @@ addImageUploadedFulfilledListener(startAppListening);
// Image deleted
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// Gallery
addGalleryImageClickedListener(startAppListening);
// User Invoked
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);

View File

@@ -1,77 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { uniq } from 'es-toolkit/compat';
import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { imagesApi } from 'services/api/endpoints/images';
export const galleryImageClicked = createAction<{
imageName: string;
shiftKey: boolean;
ctrlKey: boolean;
metaKey: boolean;
altKey: boolean;
}>('gallery/imageClicked');
/**
* This listener handles the logic for selecting images in the gallery.
*
* Previously, this logic was in a `useCallback` with the whole gallery selection as a dependency. Every time
* the selection changed, the callback got recreated and all images rerendered. This could easily block for
* hundreds of ms, more for lower end devices.
*
* Moving this logic into a listener means we don't need to recalculate anything dynamically and the gallery
* is much more responsive.
*/
export const addGalleryImageClickedListener = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: galleryImageClicked,
effect: (action, { dispatch, getState }) => {
const { imageName, shiftKey, ctrlKey, metaKey, altKey } = action.payload;
const state = getState();
const queryArgs = selectListImageNamesQueryArgs(state);
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data ?? [];
// If we don't have the image names cached, we can't perform selection operations
// This can happen if the user clicks on an image before the names are loaded
if (imageNames.length === 0) {
// For basic click without modifiers, we can still set selection
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
dispatch(selectionChanged([imageName]));
}
return;
}
const selection = state.gallery.selection;
if (altKey) {
if (state.gallery.imageToCompare === imageName) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(imageToCompareChanged(imageName));
}
} else if (shiftKey) {
const rangeEndImageName = imageName;
const lastSelectedImage = selection.at(-1);
const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage);
const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = imageNames.slice(start, end + 1);
dispatch(selectionChanged(uniq(selection.concat(imagesToSelect))));
}
} else if (ctrlKey || metaKey) {
if (selection.some((n) => n === imageName) && selection.length > 1) {
dispatch(selectionChanged(uniq(selection.filter((n) => n !== imageName))));
} else {
dispatch(selectionChanged(uniq(selection.concat(imageName))));
}
} else {
dispatch(selectionChanged([imageName]));
}
},
});
};

View File

@@ -3,9 +3,10 @@ import { draggable, monitorForElements } from '@atlaskit/pragmatic-drag-and-drop
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Image } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { galleryImageClicked } from 'app/store/middleware/listenerMiddleware/listeners/galleryImageClicked';
import { useAppStore } from 'app/store/nanostores/store';
import type { AppDispatch, AppGetState } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { uniq } from 'es-toolkit';
import { multipleImageDndSource, singleImageDndSource } from 'features/dnd/dnd';
import type { DndDragPreviewMultipleImageState } from 'features/dnd/DndDragPreviewMultipleImage';
import { createMultipleImageDragPreview, setMultipleImageDragPreview } from 'features/dnd/DndDragPreviewMultipleImage';
@@ -15,11 +16,13 @@ import { firefoxDndFix } from 'features/dnd/util';
import { useImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { GalleryImageHoverIcons } from 'features/gallery/components/ImageGrid/GalleryImageHoverIcons';
import { getGalleryImageDataTestId } from 'features/gallery/components/ImageGrid/getGalleryImageDataTestId';
import { imageToCompareChanged, selectGallerySlice } from 'features/gallery/store/gallerySlice';
import { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectGallerySlice, selectionChanged } from 'features/gallery/store/gallerySlice';
import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context';
import { VIEWER_PANEL_ID } from 'features/ui/layouts/shared';
import type { MouseEventHandler } from 'react';
import type { MouseEvent, MouseEventHandler } from 'react';
import { memo, useCallback, useEffect, useId, useMemo, useRef, useState } from 'react';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
// This class name is used to calculate the number of images that fit in the gallery
@@ -83,6 +86,54 @@ interface Props {
imageDTO: ImageDTO;
}
const buildOnClick =
(imageName: string, dispatch: AppDispatch, getState: AppGetState) => (e: MouseEvent<HTMLDivElement>) => {
const { shiftKey, ctrlKey, metaKey, altKey } = e;
const state = getState();
const queryArgs = selectListImageNamesQueryArgs(state);
const imageNames = imagesApi.endpoints.getImageNames.select(queryArgs)(state).data?.image_names ?? [];
// If we don't have the image names cached, we can't perform selection operations
// This can happen if the user clicks on an image before the names are loaded
if (imageNames.length === 0) {
// For basic click without modifiers, we can still set selection
if (!shiftKey && !ctrlKey && !metaKey && !altKey) {
dispatch(selectionChanged([imageName]));
}
return;
}
const selection = state.gallery.selection;
if (altKey) {
if (state.gallery.imageToCompare === imageName) {
dispatch(imageToCompareChanged(null));
} else {
dispatch(imageToCompareChanged(imageName));
}
} else if (shiftKey) {
const rangeEndImageName = imageName;
const lastSelectedImage = selection.at(-1);
const lastClickedIndex = imageNames.findIndex((name) => name === lastSelectedImage);
const currentClickedIndex = imageNames.findIndex((name) => name === rangeEndImageName);
if (lastClickedIndex > -1 && currentClickedIndex > -1) {
// We have a valid range!
const start = Math.min(lastClickedIndex, currentClickedIndex);
const end = Math.max(lastClickedIndex, currentClickedIndex);
const imagesToSelect = imageNames.slice(start, end + 1);
dispatch(selectionChanged(uniq(selection.concat(imagesToSelect))));
}
} else if (ctrlKey || metaKey) {
if (selection.some((n) => n === imageName) && selection.length > 1) {
dispatch(selectionChanged(uniq(selection.filter((n) => n !== imageName))));
} else {
dispatch(selectionChanged(uniq(selection.concat(imageName))));
}
} else {
dispatch(selectionChanged([imageName]));
}
};
export const GalleryImage = memo(({ imageDTO }: Props) => {
const store = useAppStore();
const autoLayoutContext = useAutoLayoutContext();
@@ -192,20 +243,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
setIsHovered(false);
}, []);
const onClick = useCallback<MouseEventHandler<HTMLDivElement>>(
(e) => {
store.dispatch(
galleryImageClicked({
imageName: imageDTO.image_name,
shiftKey: e.shiftKey,
ctrlKey: e.ctrlKey,
metaKey: e.metaKey,
altKey: e.altKey,
})
);
},
[imageDTO, store]
);
const onClick = useMemo(() => buildOnClick(imageDTO.image_name, store.dispatch, store.getState), [imageDTO, store]);
const onDoubleClick = useCallback<MouseEventHandler<HTMLDivElement>>(() => {
store.dispatch(imageToCompareChanged(null));
@@ -238,6 +276,7 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
ref={ref}
src={imageDTO.thumbnail_url}
w={imageDTO.width}
fallback={<GalleryImagePlaceholder />}
objectFit="contain"
maxW="full"
maxH="full"
@@ -253,3 +292,5 @@ export const GalleryImage = memo(({ imageDTO }: Props) => {
});
GalleryImage.displayName = 'GalleryImage';
export const GalleryImagePlaceholder = memo(() => <Box w="full" h="full" bg="base.850" borderRadius="base" />);

View File

@@ -2,15 +2,16 @@ import { Box, Flex, forwardRef, Grid, GridItem, Skeleton, Spinner, Text } from '
import { createSelector } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import { useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useRangeBasedImageFetching } from 'features/gallery/hooks/useRangeBasedImageFetching';
import type { selectListImageNamesQueryArgs } from 'features/gallery/store/gallerySelectors';
import {
LIMIT,
selectGalleryImageMinimumWidth,
selectImageToCompare,
selectLastSelectedImage,
} from 'features/gallery/store/gallerySelectors';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { useAutoLayoutContext } from 'features/ui/layouts/auto-layout-context';
import { useOverlayScrollbars } from 'overlayscrollbars-react';
import type { MutableRefObject, RefObject } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
@@ -23,22 +24,15 @@ import type {
VirtuosoGridHandle,
} from 'react-virtuoso';
import { VirtuosoGrid } from 'react-virtuoso';
import { useListImagesQuery } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import { imagesApi } from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
import { GalleryImage } from './ImageGrid/GalleryImage';
import { GalleryImage, GalleryImagePlaceholder } from './ImageGrid/GalleryImage';
import { GallerySelectionCountTag } from './ImageGrid/GallerySelectionCountTag';
import { useGalleryImageNames } from './use-gallery-image-names';
const log = logger('gallery');
// Constants
const VIEWPORT_BUFFER = 2048;
const SCROLL_SEEK_VELOCITY_THRESHOLD = 4096;
const DEBOUNCE_DELAY = 500;
const SPINNER_OPACITY = 0.3;
type ListImageNamesQueryArgs = ReturnType<typeof selectListImageNamesQueryArgs>;
type GridContext = {
@@ -46,58 +40,41 @@ type GridContext = {
imageNames: string[];
};
// Hook to get an image DTO from cache or trigger loading
const useImageDTOFromListQuery = (
index: number,
imageName: string,
queryArgs: ListImageNamesQueryArgs
): ImageDTO | null => {
const { arg, options } = useMemo(() => {
const pageOffset = Math.floor(index / LIMIT) * LIMIT;
return {
arg: {
...queryArgs,
offset: pageOffset,
limit: LIMIT,
} satisfies Parameters<typeof useListImagesQuery>[0],
options: {
selectFromResult: ({ data }) => {
const imageDTO = data?.items?.[index - pageOffset] || null;
if (imageDTO && imageDTO.image_name !== imageName) {
log.warn(`Image at index ${index} does not match expected image name ${imageName}`);
return { imageDTO: null };
}
return { imageDTO };
},
} satisfies Parameters<typeof useListImagesQuery>[1],
};
}, [index, queryArgs, imageName]);
const ImageAtPosition = memo(({ imageName }: { index: number; imageName: string }) => {
/*
* We rely on the useRangeBasedImageFetching to fetch all image DTOs, caching them with RTK Query.
*
* In this component, we just want to consume that cache. Unforutnately, RTK Query does not provide a way to
* subscribe to a query without triggering a new fetch.
*
* There is a hack, though:
* - https://github.com/reduxjs/redux-toolkit/discussions/4213
*
* This essentially means "subscribe to the query once it has some data".
*/
const { imageDTO } = useListImagesQuery(arg, options);
// Use `currentData` instead of `data` to prevent a flash of previous image rendered at this index
const { currentData: imageDTO, isUninitialized } = imagesApi.endpoints.getImageDTO.useQueryState(imageName);
imagesApi.endpoints.getImageDTO.useQuerySubscription(imageName, { skip: isUninitialized });
return imageDTO;
};
// Individual image component that gets its data from RTK Query cache
const ImageAtPosition = memo(
({ index, queryArgs, imageName }: { index: number; imageName: string; queryArgs: ListImageNamesQueryArgs }) => {
const imageDTO = useImageDTOFromListQuery(index, imageName, queryArgs);
if (!imageDTO) {
return <Skeleton w="full" h="full" />;
}
return <GalleryImage imageDTO={imageDTO} />;
if (!imageDTO) {
return <GalleryImagePlaceholder />;
}
);
return <GalleryImage imageDTO={imageDTO} />;
});
ImageAtPosition.displayName = 'ImageAtPosition';
// Memoized compute key function using image names
const computeItemKey: GridComputeItemKey<string, GridContext> = (_index, imageName, { queryArgs }) => {
return `${JSON.stringify(queryArgs)}-${imageName}`;
const computeItemKey: GridComputeItemKey<string, GridContext> = (index, imageName, { queryArgs }) => {
return `${JSON.stringify(queryArgs)}-${imageName ?? index}`;
};
// Physical DOM-based grid calculation using refs (based on working old implementation)
/**
* Calculate how many images fit in a row based on the current grid layout.
*
* TODO(psyche): We only need to do this when the gallery width changes, or when the galleryImageMinimumWidth value
* changes. Cache this calculation.
*/
const getImagesPerRow = (rootEl: HTMLDivElement): number => {
// Start from root and find virtuoso grid elements
const gridElement = rootEl.querySelector('.virtuoso-grid-list');
@@ -124,7 +101,14 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => {
return 0;
}
// Use the exact calculation from the working old implementation
/**
* You might be tempted to just do some simple math like:
* const imagesPerRow = Math.floor(containerRect.width / itemRect.width);
*
* But floating point precision can cause issues with this approach, causing it to be off by 1 in some cases.
*
* Instead, we use a more robust approach that iteratively calculates how many images fit in the row.
*/
let imagesPerRow = 0;
let spaceUsed = 0;
@@ -141,7 +125,9 @@ const getImagesPerRow = (rootEl: HTMLDivElement): number => {
return Math.max(1, imagesPerRow);
};
// Check if an item at a given index is visible in the viewport
/**
* Scroll the item at the given index into view if it is not currently visible.
*/
const scrollIntoView = (
index: number,
rootEl: HTMLDivElement,
@@ -202,6 +188,11 @@ const scrollIntoView = (
return;
};
/**
* Get the index of the image in the list of image names.
* If the image name is not found, return 0.
* If no image name is provided, return 0.
*/
const getImageIndex = (imageName: string | undefined | null, imageNames: string[]) => {
if (!imageName || imageNames.length === 0) {
return 0;
@@ -210,7 +201,9 @@ const getImageIndex = (imageName: string | undefined | null, imageNames: string[
return index >= 0 ? index : 0;
};
// Hook for keyboard navigation using physical DOM measurements
/**
* Handles keyboard navigation for the gallery.
*/
const useKeyboardNavigation = (
imageNames: string[],
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
@@ -249,11 +242,12 @@ const useKeyboardNavigation = (
event.preventDefault();
const state = getState();
const imageName = event.altKey
? // When the user holds alt, we are changing the image to compare - if no image to compare is currently selected,
// we start from the last selected image
(selectImageToCompare(getState()) ?? selectLastSelectedImage(getState()))
: selectLastSelectedImage(getState());
(selectImageToCompare(state) ?? selectLastSelectedImage(state))
: selectLastSelectedImage(state);
const currentIndex = getImageIndex(imageName, imageNames);
@@ -373,6 +367,11 @@ const useKeyboardNavigation = (
});
};
/**
* Keeps the last selected image in view when the gallery is scrolled.
* This is useful for keyboard navigation and ensuring the user can see their selection.
* It only tracks the last selected image, not the image to compare.
*/
const useKeepSelectedImageInView = (
imageNames: string[],
virtuosoRef: React.RefObject<VirtuosoGridHandle>,
@@ -397,6 +396,9 @@ const useKeepSelectedImageInView = (
}, [imageName, imageNames, rangeRef, rootRef, virtuosoRef]);
};
/**
* Handles the initialization of the overlay scrollbars for the gallery, returning the ref to the scroller element.
*/
const useScrollableGallery = (rootRef: RefObject<HTMLDivElement>) => {
const [scroller, scrollerRef] = useState<HTMLElement | null>(null);
const [initialize, osInstance] = useOverlayScrollbars({
@@ -431,43 +433,49 @@ const useScrollableGallery = (rootRef: RefObject<HTMLDivElement>) => {
return scrollerRef;
};
// Main gallery component
export const NewGallery = memo(() => {
const virtuosoRef = useRef<VirtuosoGridHandle>(null);
const rangeRef = useRef<ListRange>({ startIndex: 0, endIndex: 0 });
const rootRef = useRef<HTMLDivElement>(null);
const { isActiveTab } = useAutoLayoutContext();
// Get the ordered list of image names - this is our primary data source for virtualization
const { queryArgs, imageNames, isLoading } = useGalleryImageNames();
// Use range-based fetching for bulk loading image DTOs into cache based on the visible range
const { onRangeChanged } = useRangeBasedImageFetching({
imageNames,
enabled: !isLoading && isActiveTab,
});
useKeepSelectedImageInView(imageNames, virtuosoRef, rootRef, rangeRef);
useKeyboardNavigation(imageNames, virtuosoRef, rootRef);
const scrollerRef = useScrollableGallery(rootRef);
// We have to keep track of the visible range for keep-selected-image-in-view functionality
const handleRangeChanged = useCallback((range: ListRange) => {
rangeRef.current = range;
}, []);
const context = useMemo(
() =>
({
imageNames,
queryArgs,
}) satisfies GridContext,
[imageNames, queryArgs]
/*
* We have to keep track of the visible range for keep-selected-image-in-view functionality and push the range to
* the range-based image fetching hook.
*/
const handleRangeChanged = useCallback(
(range: ListRange) => {
rangeRef.current = range;
onRangeChanged(range);
},
[onRangeChanged]
);
const context = useMemo<GridContext>(() => ({ imageNames, queryArgs }), [imageNames, queryArgs]);
// Item content function
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName, ctx) => {
return <ImageAtPosition index={index} imageName={imageName} queryArgs={ctx.queryArgs} />;
const itemContent: GridItemContent<string, GridContext> = useCallback((index, imageName) => {
return <ImageAtPosition index={index} imageName={imageName} />;
}, []);
if (isLoading) {
return (
<Flex height="100%" alignItems="center" justifyContent="center">
<Spinner size="lg" opacity={SPINNER_OPACITY} />
<Text ml={4}>Loading gallery...</Text>
<Flex height="100%" alignItems="center" justifyContent="center" gap={4}>
<Spinner size="lg" opacity={0.3} />
<Text color="base.300">Loading gallery...</Text>
</Flex>
);
}
@@ -481,12 +489,13 @@ export const NewGallery = memo(() => {
}
return (
// This wrapper component is necessary to initialize the overlay scrollbars!
<Box data-overlayscrollbars-initialize="" ref={rootRef} position="relative" w="full" h="full">
<VirtuosoGrid<string, GridContext>
ref={virtuosoRef}
context={context}
data={imageNames}
increaseViewportBy={VIEWPORT_BUFFER}
increaseViewportBy={2048}
itemContent={itemContent}
computeItemKey={computeItemKey}
components={components}
@@ -503,7 +512,7 @@ export const NewGallery = memo(() => {
NewGallery.displayName = 'NewGallery';
const scrollSeekConfiguration: ScrollSeekConfiguration = {
enter: (velocity) => velocity > SCROLL_SEEK_VELOCITY_THRESHOLD,
enter: (velocity) => velocity > 4096,
exit: (velocity) => velocity === 0,
};
@@ -518,7 +527,7 @@ const selectGridTemplateColumns = createSelector(
// Grid components
const ListComponent: GridComponents<GridContext>['List'] = forwardRef(({ context: _, ...rest }, ref) => {
const _gridTemplateColumns = useAppSelector(selectGridTemplateColumns);
const [gridTemplateColumns] = useDebounce(_gridTemplateColumns, DEBOUNCE_DELAY);
const [gridTemplateColumns] = useDebounce(_gridTemplateColumns, 300);
return <Grid ref={ref} gridTemplateColumns={gridTemplateColumns} gap={1} {...rest} />;
});

View File

@@ -5,8 +5,8 @@ import { useGetImageNamesQuery } from 'services/api/endpoints/images';
import { useDebounce } from 'use-debounce';
const getImageNamesQueryOptions = {
selectFromResult: ({ data, isLoading, isFetching }) => ({
imageNames: data ?? EMPTY_ARRAY,
selectFromResult: ({ currentData, isLoading, isFetching }) => ({
imageNames: currentData?.image_names ?? EMPTY_ARRAY,
isLoading,
isFetching,
}),
@@ -14,7 +14,7 @@ const getImageNamesQueryOptions = {
export const useGalleryImageNames = () => {
const _queryArgs = useAppSelector(selectListImageNamesQueryArgs);
const [queryArgs] = useDebounce(_queryArgs, 500);
const [queryArgs] = useDebounce(_queryArgs, 300);
const { imageNames, isLoading, isFetching } = useGetImageNamesQuery(queryArgs, getImageNamesQueryOptions);
return { imageNames, isLoading, isFetching, queryArgs };
};

View File

@@ -0,0 +1,86 @@
import { useAppStore } from 'app/store/storeHooks';
import { useCallback, useEffect, useState } from 'react';
import type { ListRange } from 'react-virtuoso';
import { imagesApi, useGetImageDTOsByNamesMutation } from 'services/api/endpoints/images';
import { useThrottledCallback } from 'use-debounce';
interface UseRangeBasedImageFetchingArgs {
imageNames: string[];
enabled: boolean;
}
interface UseRangeBasedImageFetchingReturn {
onRangeChanged: (range: ListRange) => void;
}
const getUncachedNames = (imageNames: string[], cachedImageNames: string[], range: ListRange): string[] => {
if (range.startIndex === range.endIndex) {
// If the start and end indices are the same, no range to fetch
return [];
}
if (imageNames.length === 0) {
return [];
}
const start = Math.max(0, range.startIndex);
const end = Math.min(imageNames.length - 1, range.endIndex);
if (cachedImageNames.length === 0) {
return imageNames.slice(start, end + 1);
}
const uncachedNames: string[] = [];
for (let i = start; i <= end; i++) {
const imageName = imageNames[i]!;
if (!cachedImageNames.includes(imageName)) {
uncachedNames.push(imageName);
}
}
return uncachedNames;
};
/**
* Hook for bulk fetching image DTOs based on the visible range from virtuoso.
* Individual image components should use `useGetImageDTOQuery(imageName)` to get their specific DTO.
* This hook ensures DTOs are bulk fetched and cached efficiently.
*/
export const useRangeBasedImageFetching = ({
imageNames,
enabled,
}: UseRangeBasedImageFetchingArgs): UseRangeBasedImageFetchingReturn => {
const store = useAppStore();
const [visibleRange, setVisibleRange] = useState<ListRange>({ startIndex: 0, endIndex: 0 });
const [getImageDTOsByNames] = useGetImageDTOsByNamesMutation();
const fetchImages = useCallback(
(visibleRange: ListRange) => {
const cachedImageNames = imagesApi.util.selectCachedArgsForQuery(store.getState(), 'getImageDTO');
const uncachedNames = getUncachedNames(imageNames, cachedImageNames, visibleRange);
if (uncachedNames.length === 0) {
return;
}
getImageDTOsByNames({ image_names: uncachedNames });
},
[getImageDTOsByNames, imageNames, store]
);
const throttledFetchImages = useThrottledCallback(fetchImages, 100);
useEffect(() => {
if (!enabled) {
return;
}
throttledFetchImages(visibleRange);
}, [enabled, throttledFetchImages, imageNames, visibleRange]);
const onRangeChanged = useCallback((range: ListRange) => {
setVisibleRange(range);
}, []);
return {
onRangeChanged,
};
};

View File

@@ -1,5 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import type { DockviewApi, GridviewApi } from 'dockview';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import type { TabName } from 'features/ui/store/uiTypes';
import type { WritableAtom } from 'nanostores';
import { atom } from 'nanostores';
import type { PropsWithChildren, RefObject } from 'react';
@@ -8,6 +12,7 @@ import { createContext, memo, useCallback, useContext, useMemo, useState } from
import { LEFT_PANEL_ID, LEFT_PANEL_MIN_SIZE_PX, RIGHT_PANEL_ID, RIGHT_PANEL_MIN_SIZE_PX } from './shared';
type AutoLayoutContextValue = {
isActiveTab: boolean;
toggleLeftPanel: () => void;
toggleRightPanel: () => void;
toggleBothPanels: () => void;
@@ -57,9 +62,15 @@ const activatePanel = (api: GridviewApi | DockviewApi, panelId: string) => {
};
export const AutoLayoutProvider = (
props: PropsWithChildren<{ $rootApi: WritableAtom<GridviewApi | null>; rootRef: RefObject<HTMLDivElement> }>
props: PropsWithChildren<{
$rootApi: WritableAtom<GridviewApi | null>;
rootRef: RefObject<HTMLDivElement>;
tab: TabName;
}>
) => {
const { $rootApi, rootRef, children } = props;
const { $rootApi, rootRef, tab, children } = props;
const selectIsActiveTab = useMemo(() => createSelector(selectActiveTab, (activeTab) => activeTab === tab), [tab]);
const isActiveTab = useAppSelector(selectIsActiveTab);
const $leftApi = useState(() => atom<GridviewApi | null>(null))[0];
const $centerApi = useState(() => atom<DockviewApi | null>(null))[0];
const $rightApi = useState(() => atom<GridviewApi | null>(null))[0];
@@ -126,6 +137,7 @@ export const AutoLayoutProvider = (
const value = useMemo<AutoLayoutContextValue>(
() => ({
isActiveTab,
toggleLeftPanel,
toggleRightPanel,
toggleBothPanels,
@@ -138,6 +150,7 @@ export const AutoLayoutProvider = (
_$rightPanelApi: $rightApi,
}),
[
isActiveTab,
$centerApi,
$leftApi,
$rightApi,

View File

@@ -259,7 +259,7 @@ export const CanvasTabAutoLayout = memo(() => {
useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef);
return (
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef}>
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef} tab="canvas">
<GridviewReact
ref={rootRef}
className="dockview-theme-invoke"

View File

@@ -234,7 +234,7 @@ export const GenerateTabAutoLayout = memo(() => {
useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef);
return (
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef}>
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef} tab="generate">
<GridviewReact
ref={rootRef}
className="dockview-theme-invoke"

View File

@@ -229,7 +229,7 @@ export const UpscalingTabAutoLayout = memo(() => {
useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef);
return (
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef}>
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef} tab="upscaling">
<GridviewReact
ref={rootRef}
className="dockview-theme-invoke"

View File

@@ -247,7 +247,7 @@ export const WorkflowsTabAutoLayout = memo(() => {
useResizeMainPanelOnFirstVisit($rootPanelApi, rootRef);
return (
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef}>
<AutoLayoutProvider $rootApi={$rootPanelApi} rootRef={rootRef} tab="workflows">
<GridviewReact
ref={rootRef}
className="dockview-theme-invoke"

View File

@@ -7,6 +7,7 @@ import type {
GraphAndWorkflowResponse,
ImageCategory,
ImageDTO,
ImageNamesResult,
ImageUploadEntryRequest,
ImageUploadEntryResponse,
ListImagesArgs,
@@ -431,7 +432,7 @@ export const imagesApi = api.injectEndpoints({
* Get ordered list of image names for selection operations
*/
getImageNames: build.query<
string[],
ImageNamesResult,
{
categories?: ImageCategory[] | null;
is_intermediate?: boolean | null;
@@ -450,6 +451,38 @@ export const imagesApi = api.injectEndpoints({
{ type: 'ImageNameList', id: stableHash(queryArgs) },
],
}),
/**
* Get image DTOs for the specified image names. Maintains order of input names.
*/
getImageDTOsByNames: build.mutation<
paths['/api/v1/images/images_by_names']['post']['responses']['200']['content']['application/json'],
paths['/api/v1/images/images_by_names']['post']['requestBody']['content']['application/json']
>({
query: (body) => ({
url: buildImagesUrl('images_by_names'),
method: 'POST',
body,
}),
// Don't provide cache tags - we'll manually upsert into individual getImageDTO caches
async onQueryStarted(_, { dispatch, queryFulfilled }) {
try {
const { data: imageDTOs } = await queryFulfilled;
// Upsert each DTO into the individual image cache
const updates: Param0<typeof imagesApi.util.upsertQueryEntries> = [];
for (const imageDTO of imageDTOs) {
updates.push({
endpointName: 'getImageDTO',
arg: imageDTO.image_name,
value: imageDTO,
});
}
dispatch(imagesApi.util.upsertQueryEntries(updates));
} catch {
// Handle error if needed
}
},
}),
}),
});
@@ -472,6 +505,7 @@ export const {
useUnstarImagesMutation,
useBulkDownloadImagesMutation,
useGetImageNamesQuery,
useGetImageDTOsByNamesMutation,
} = imagesApi;
/**

View File

@@ -761,7 +761,7 @@ export type paths = {
};
/**
* Get Image Names
* @description Gets ordered list of all image names (starred first, then unstarred)
* @description Gets ordered list of image names with metadata for optimistic updates
*/
get: operations["get_image_names"];
put?: never;
@@ -772,6 +772,26 @@ export type paths = {
patch?: never;
trace?: never;
};
"/api/v1/images/images_by_names": {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
get?: never;
put?: never;
/**
* Get Images By Names
* @description Gets image DTOs for the specified image names. Maintains order of input names.
*/
post: operations["get_images_by_names"];
delete?: never;
options?: never;
head?: never;
patch?: never;
trace?: never;
};
"/api/v1/boards/": {
parameters: {
query?: never;
@@ -2648,6 +2668,14 @@ export type components = {
/** @description The validation run data to use for this batch. This is only used if this is a validation run. */
validation_run_data?: components["schemas"]["ValidationRunData"] | null;
};
/** Body_get_images_by_names */
Body_get_images_by_names: {
/**
* Image Names
* @description Object containing list of image names to fetch DTOs for
*/
image_names: string[];
};
/** Body_import_style_presets */
Body_import_style_presets: {
/**
@@ -10479,6 +10507,27 @@ export type components = {
*/
type: "img_nsfw";
};
/**
* ImageNamesResult
* @description Response containing ordered image names with metadata for optimistic updates.
*/
ImageNamesResult: {
/**
* Image Names
* @description Ordered list of image names
*/
image_names: string[];
/**
* Starred Count
* @description Number of starred images (when starred_first=True)
*/
starred_count: number;
/**
* Total Count
* @description Total number of images matching the query
*/
total_count: number;
};
/**
* Add Image Noise
* @description Add noise to an image
@@ -23725,7 +23774,40 @@ export interface operations {
[name: string]: unknown;
};
content: {
"application/json": string[];
"application/json": components["schemas"]["ImageNamesResult"];
};
};
/** @description Validation Error */
422: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["HTTPValidationError"];
};
};
};
};
get_images_by_names: {
parameters: {
query?: never;
header?: never;
path?: never;
cookie?: never;
};
requestBody: {
content: {
"application/json": components["schemas"]["Body_get_images_by_names"];
};
};
responses: {
/** @description Successful Response */
200: {
headers: {
[name: string]: unknown;
};
content: {
"application/json": components["schemas"]["ImageDTO"][];
};
};
/** @description Validation Error */

View File

@@ -7,6 +7,8 @@ export type S = components['schemas'];
export type ListImagesArgs = NonNullable<paths['/api/v1/images/']['get']['parameters']['query']>;
export type ListImagesResponse = paths['/api/v1/images/']['get']['responses']['200']['content']['application/json'];
export type ImageNamesResult = S['ImageNamesResult'];
export type ListBoardsArgs = NonNullable<paths['/api/v1/boards/']['get']['parameters']['query']>;
export type DeleteBoardResult =

View File

@@ -0,0 +1,90 @@
import type { ImageDTO, ImageNamesResult } from 'services/api/types';
/**
* Calculates the optimal insertion position for a new image in the names list.
* For starred_first=true: starred images go to position 0, unstarred go after all starred images
* For starred_first=false: all new images go to position 0 (newest first)
*/
export function calculateImageInsertionPosition(
imageDTO: ImageDTO,
starredFirst: boolean,
starredCount: number
): number {
if (!starredFirst) {
// When starred_first is false, always insert at the beginning (newest first)
return 0;
}
// When starred_first is true
if (imageDTO.starred) {
// Starred images go at the very beginning
return 0;
}
// Unstarred images go after all starred images
return starredCount;
}
/**
* Optimistically inserts a new image into the ImageNamesResult at the correct position
*/
export function insertImageIntoNamesResult(
currentResult: ImageNamesResult,
imageDTO: ImageDTO,
starredFirst: boolean
): ImageNamesResult {
// Don't insert if the image is already in the list
if (currentResult.image_names.includes(imageDTO.image_name)) {
return currentResult;
}
const insertPosition = calculateImageInsertionPosition(imageDTO, starredFirst, currentResult.starred_count);
const newImageNames = [...currentResult.image_names];
newImageNames.splice(insertPosition, 0, imageDTO.image_name);
return {
image_names: newImageNames,
starred_count: starredFirst && imageDTO.starred ? currentResult.starred_count + 1 : currentResult.starred_count,
total_count: currentResult.total_count + 1,
};
}
/**
* Optimistically removes an image from the ImageNamesResult
*/
export function removeImageFromNamesResult(
currentResult: ImageNamesResult,
imageNameToRemove: string,
wasStarred: boolean,
starredFirst: boolean
): ImageNamesResult {
const newImageNames = currentResult.image_names.filter((name) => name !== imageNameToRemove);
return {
image_names: newImageNames,
starred_count: starredFirst && wasStarred ? currentResult.starred_count - 1 : currentResult.starred_count,
total_count: currentResult.total_count - 1,
};
}
/**
* Optimistically updates an image's position in the result when its starred status changes
*/
export function updateImagePositionInNamesResult(
currentResult: ImageNamesResult,
updatedImageDTO: ImageDTO,
previouslyStarred: boolean,
starredFirst: boolean
): ImageNamesResult {
// First remove the image from its current position
const withoutImage = removeImageFromNamesResult(
currentResult,
updatedImageDTO.image_name,
previouslyStarred,
starredFirst
);
// Then insert it at the new correct position
return insertImageIntoNamesResult(withoutImage, updatedImageDTO, starredFirst);
}

View File

@@ -1,24 +1,22 @@
import { logger } from 'app/logging/logger';
import { addAppListener } from 'app/store/middleware/listenerMiddleware';
import type { AppDispatch, AppGetState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import {
selectAutoSwitch,
selectGalleryView,
selectListImagesBaseQueryArgs,
selectListImageNamesQueryArgs,
selectSelectedBoardId,
} from 'features/gallery/store/gallerySelectors';
import { boardIdSelected, galleryViewChanged, imageSelected } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import type { ApiTagDescription } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories } from 'services/api/util';
import { insertImageIntoNamesResult } from 'services/api/util/optimisticUpdates';
import { $lastProgressEvent } from 'services/events/stores';
import stableHash from 'stable-hash';
import type { Param0 } from 'tsafe';
import { objectEntries } from 'tsafe';
import type { JsonObject } from 'type-fest';
@@ -42,9 +40,7 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
// For efficiency's sake, we want to minimize the number of dispatches and invalidations we do.
// We'll keep track of each change we need to make and do them all at once.
const boardTotalAdditions: Record<string, number> = {};
const boardTagIdsToInvalidate: Set<string> = new Set();
const imageListTagIdsToInvalidate: Set<string> = new Set();
const listImagesArg = selectListImagesBaseQueryArgs(getState());
const listImageNamesArg = selectListImageNamesQueryArgs(getState());
for (const imageDTO of imageDTOs) {
if (imageDTO.is_intermediate) {
@@ -54,17 +50,6 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
const board_id = imageDTO.board_id ?? 'none';
// update the total images for the board
boardTotalAdditions[board_id] = (boardTotalAdditions[board_id] || 0) + 1;
// invalidate the board tag
boardTagIdsToInvalidate.add(board_id);
// invalidate the image list tag
imageListTagIdsToInvalidate.add(
stableHash({
...listImagesArg,
categories: getCategories(imageDTO),
board_id,
offset: 0,
})
);
}
// Update all the board image totals at once
@@ -85,16 +70,40 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
}
dispatch(boardsApi.util.upsertQueryEntries(entries));
// Invalidate all tags at once
const boardTags: ApiTagDescription[] = Array.from(boardTagIdsToInvalidate).map((boardId) => ({
type: 'Board' as const,
id: boardId,
}));
const imageListTags: ApiTagDescription[] = Array.from(imageListTagIdsToInvalidate).map((imageListId) => ({
type: 'ImageList' as const,
id: imageListId,
}));
dispatch(imagesApi.util.invalidateTags(['ImageNameList', ...boardTags, ...imageListTags]));
// Optimistically update image names lists - DTOs are already cached by getResultImageDTOs
const state = getState();
for (const imageDTO of imageDTOs) {
// Construct the expected query args for this image's getImageNames query
// Use the current gallery query args as base, but override board_id and categories for this specific image
const expectedQueryArgs = {
...listImageNamesArg,
categories: getCategories(imageDTO),
board_id: imageDTO.board_id ?? 'none',
};
// Check if we have cached image names for this query
const cachedNamesResult = imagesApi.endpoints.getImageNames.select(expectedQueryArgs)(state);
if (cachedNamesResult.data) {
// We have cached names - optimistically insert the new image
dispatch(
imagesApi.util.updateQueryData('getImageNames', expectedQueryArgs, (draft) => {
// Use the utility function to insert at the correct position
const updatedResult = insertImageIntoNamesResult(draft, imageDTO, expectedQueryArgs.starred_first ?? true);
// Replace the draft contents
draft.image_names = updatedResult.image_names;
draft.starred_count = updatedResult.starred_count;
draft.total_count = updatedResult.total_count;
})
);
}
// If no cached data, we don't need to do anything - there's no list to update
}
// No need to invalidate tags since we're doing optimistic updates
// Board totals are already updated above via upsertQueryEntries
const autoSwitch = selectAutoSwitch(getState());
@@ -112,63 +121,27 @@ export const buildOnInvocationComplete = (getState: AppGetState, dispatch: AppDi
const { image_name } = lastImageDTO;
const board_id = lastImageDTO.board_id ?? 'none';
/**
* Auto-switch needs a bit of care to avoid race conditions - we need to invalidate the appropriate image list
* query cache, and only after it has loaded, select the new image.
*/
const queryArgs = {
...listImagesArg,
categories: getCategories(lastImageDTO),
board_id,
offset: 0,
};
// With optimistic updates, we can immediately switch to the new image
const selectedBoardId = selectSelectedBoardId(getState());
dispatch(
addAppListener({
predicate: (action) => {
if (!imagesApi.endpoints.listImages.matchFulfilled(action)) {
return false;
}
if (stableHash(action.meta.arg.originalArgs) !== stableHash(queryArgs)) {
return false;
}
return true;
},
effect: (_action, { getState, dispatch, unsubscribe }) => {
// This is a one-time listener - we always unsubscribe after the first match
unsubscribe();
// Auto-switch may have been disabled while we were waiting for the query to resolve - bail if so
const autoSwitch = selectAutoSwitch(getState());
if (!autoSwitch) {
return;
}
const selectedBoardId = selectSelectedBoardId(getState());
// If the image is from a different board, switch to that board & select the image - otherwise just select the
// image. This implicitly changes the view to 'images' if it was not already.
if (board_id !== selectedBoardId) {
dispatch(
boardIdSelected({
boardId: board_id,
selectedImageName: image_name,
})
);
} else {
// Ensure we are on the 'images' gallery view - that's where this image will be displayed
const galleryView = selectGalleryView(getState());
if (galleryView !== 'images') {
dispatch(galleryViewChanged('images'));
}
// Else just select the image, no need to switch boards
dispatch(imageSelected(lastImageDTO.image_name));
}
},
})
);
// If the image is from a different board, switch to that board & select the image - otherwise just select the
// image. This implicitly changes the view to 'images' if it was not already.
if (board_id !== selectedBoardId) {
dispatch(
boardIdSelected({
boardId: board_id,
selectedImageName: image_name,
})
);
} else {
// Ensure we are on the 'images' gallery view - that's where this image will be displayed
const galleryView = selectGalleryView(getState());
if (galleryView !== 'images') {
dispatch(galleryViewChanged('images'));
}
// Select the image immediately since we've optimistically updated the cache
dispatch(imageSelected(lastImageDTO.image_name));
}
};
const getResultImageDTOs = async (data: S['InvocationCompleteEvent']): Promise<ImageDTO[]> => {