mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): add _all_ image outputs to gallery (including collections)
This commit is contained in:
@@ -8,6 +8,11 @@ export const zImageField = z.object({
|
||||
image_name: z.string().trim().min(1),
|
||||
});
|
||||
export type ImageField = z.infer<typeof zImageField>;
|
||||
export const isImageField = (field: unknown): field is ImageField => zImageField.safeParse(field).success;
|
||||
const zImageFieldCollection = z.array(zImageField);
|
||||
type ImageFieldCollection = z.infer<typeof zImageFieldCollection>;
|
||||
export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection =>
|
||||
zImageFieldCollection.safeParse(field).success;
|
||||
|
||||
export const zBoardField = z.object({
|
||||
board_id: z.string().trim().min(1),
|
||||
|
||||
@@ -4,13 +4,17 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import { stagingAreaImageStaged } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
|
||||
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
|
||||
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
|
||||
import { zNodeStatus } from 'features/nodes/types/invocation';
|
||||
import { CANVAS_OUTPUT_PREFIX } from 'features/nodes/util/graph/graphBuilderUtils';
|
||||
import type { ApiTagDescription } from 'services/api';
|
||||
import { boardsApi } from 'services/api/endpoints/boards';
|
||||
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO, S } from 'services/api/types';
|
||||
import { getCategories, getListImagesUrl } from 'services/api/util';
|
||||
import { $lastProgressEvent } from 'services/events/stores';
|
||||
import type { Param0 } from 'tsafe';
|
||||
import { objectEntries } from 'tsafe';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('events');
|
||||
@@ -22,58 +26,98 @@ const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => {
|
||||
const nodeTypeDenylist = ['load_image', 'image'];
|
||||
|
||||
export const buildOnInvocationComplete = (getState: () => RootState, dispatch: AppDispatch) => {
|
||||
const addImageToGallery = (data: S['InvocationCompleteEvent'], imageDTO: ImageDTO) => {
|
||||
const addImagesToGallery = (data: S['InvocationCompleteEvent'], imageDTOs: ImageDTO[]) => {
|
||||
if (nodeTypeDenylist.includes(data.invocation.type)) {
|
||||
log.trace('Skipping node type denylisted');
|
||||
log.trace(`Skipping denylisted node type (${data.invocation.type})`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (imageDTO.is_intermediate) {
|
||||
// For efficiency's sake, we want to minimize the number of dispatches and invalidations we do.
|
||||
// We'll keep track of each change we need to make and do them all at once.
|
||||
const boardTotalAdditions: Record<string, number> = {};
|
||||
const boardTagIdsToInvalidate: Set<string> = new Set();
|
||||
const imageListTagIdsToInvalidate: Set<string> = new Set();
|
||||
|
||||
for (const imageDTO of imageDTOs) {
|
||||
if (imageDTO.is_intermediate) {
|
||||
return;
|
||||
}
|
||||
|
||||
const boardId = imageDTO.board_id ?? 'none';
|
||||
// update the total images for the board
|
||||
boardTotalAdditions[boardId] = (boardTotalAdditions[boardId] || 0) + 1;
|
||||
// invalidate the board tag
|
||||
boardTagIdsToInvalidate.add(boardId);
|
||||
// invalidate the image list tag
|
||||
imageListTagIdsToInvalidate.add(
|
||||
getListImagesUrl({
|
||||
board_id: boardId,
|
||||
categories: getCategories(imageDTO),
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
// Update all the board image totals at once
|
||||
const entries: Param0<typeof boardsApi.util.upsertQueryEntries> = [];
|
||||
for (const [boardId, amountToAdd] of objectEntries(boardTotalAdditions)) {
|
||||
// upsertQueryEntries doesn't provide a "recipe" function for the update - we must provide the new value
|
||||
// directly. So we need to select the board totals first.
|
||||
const total = boardsApi.endpoints.getBoardImagesTotal.select(boardId)(getState()).data?.total;
|
||||
if (total === undefined) {
|
||||
// No cache exists for this board, so we can't update it.
|
||||
continue;
|
||||
}
|
||||
entries.push({
|
||||
endpointName: 'getBoardImagesTotal',
|
||||
arg: boardId,
|
||||
value: { total: total + amountToAdd },
|
||||
});
|
||||
}
|
||||
dispatch(boardsApi.util.upsertQueryEntries(entries));
|
||||
|
||||
// Invalidate all tags at once
|
||||
const boardTags: ApiTagDescription[] = Array.from(boardTagIdsToInvalidate).map((boardId) => ({
|
||||
type: 'Board' as const,
|
||||
id: boardId,
|
||||
}));
|
||||
const imageListTags: ApiTagDescription[] = Array.from(imageListTagIdsToInvalidate).map((imageListId) => ({
|
||||
type: 'ImageList' as const,
|
||||
id: imageListId,
|
||||
}));
|
||||
dispatch(imagesApi.util.invalidateTags([...boardTags, ...imageListTags]));
|
||||
|
||||
// Finally, we may need to autoswitch to the new image. We'll only do it for the last image in the list.
|
||||
|
||||
const lastImageDTO = imageDTOs.at(-1);
|
||||
|
||||
if (!lastImageDTO) {
|
||||
return;
|
||||
}
|
||||
|
||||
// update the total images for the board
|
||||
dispatch(
|
||||
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
|
||||
draft.total += 1;
|
||||
})
|
||||
);
|
||||
|
||||
dispatch(
|
||||
imagesApi.util.invalidateTags([
|
||||
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
|
||||
{
|
||||
type: 'ImageList',
|
||||
id: getListImagesUrl({
|
||||
board_id: imageDTO.board_id ?? 'none',
|
||||
categories: getCategories(imageDTO),
|
||||
}),
|
||||
},
|
||||
])
|
||||
);
|
||||
const { image_name, board_id } = lastImageDTO;
|
||||
|
||||
const { shouldAutoSwitch, selectedBoardId, galleryView, offset } = getState().gallery;
|
||||
|
||||
// If auto-switch is enabled, select the new image
|
||||
if (shouldAutoSwitch) {
|
||||
// If the image is from a different board, switch to that board - this will also select the image
|
||||
if (imageDTO.board_id && imageDTO.board_id !== selectedBoardId) {
|
||||
if (board_id && board_id !== selectedBoardId) {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
boardId: imageDTO.board_id,
|
||||
selectedImageName: imageDTO.image_name,
|
||||
boardId: board_id,
|
||||
selectedImageName: image_name,
|
||||
})
|
||||
);
|
||||
} else if (!imageDTO.board_id && selectedBoardId !== 'none') {
|
||||
} else if (!board_id && selectedBoardId !== 'none') {
|
||||
dispatch(
|
||||
boardIdSelected({
|
||||
boardId: 'none',
|
||||
selectedImageName: imageDTO.image_name,
|
||||
selectedImageName: image_name,
|
||||
})
|
||||
);
|
||||
} else {
|
||||
// Else just select the image, no need to switch boards
|
||||
dispatch(imageSelected(imageDTO));
|
||||
dispatch(imageSelected(lastImageDTO));
|
||||
|
||||
if (galleryView !== 'images') {
|
||||
// We also need to update the gallery view to images. This also updates the offset.
|
||||
@@ -86,12 +130,25 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
|
||||
}
|
||||
};
|
||||
|
||||
const getResultImageDTO = (data: S['InvocationCompleteEvent']) => {
|
||||
const getResultImageDTOs = async (data: S['InvocationCompleteEvent']): Promise<ImageDTO[]> => {
|
||||
const { result } = data;
|
||||
if (result.type === 'image_output') {
|
||||
return getImageDTOSafe(result.image.image_name);
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
for (const [_name, value] of objectEntries(result)) {
|
||||
if (isImageField(value)) {
|
||||
const imageDTO = await getImageDTOSafe(value.image_name);
|
||||
if (imageDTO) {
|
||||
imageDTOs.push(imageDTO);
|
||||
}
|
||||
} else if (isImageFieldCollection(value)) {
|
||||
for (const imageField of value) {
|
||||
const imageDTO = await getImageDTOSafe(imageField.image_name);
|
||||
if (imageDTO) {
|
||||
imageDTOs.push(imageDTO);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
return imageDTOs;
|
||||
};
|
||||
|
||||
const handleOriginWorkflows = async (data: S['InvocationCompleteEvent']) => {
|
||||
@@ -107,16 +164,15 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
|
||||
upsertExecutionState(nes.nodeId, nes);
|
||||
}
|
||||
|
||||
const imageDTO = await getResultImageDTO(data);
|
||||
|
||||
if (imageDTO && !imageDTO.is_intermediate) {
|
||||
addImageToGallery(data, imageDTO);
|
||||
}
|
||||
const imageDTOs = await getResultImageDTOs(data);
|
||||
addImagesToGallery(data, imageDTOs);
|
||||
};
|
||||
|
||||
const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => {
|
||||
const imageDTO = await getResultImageDTO(data);
|
||||
const imageDTOs = await getResultImageDTOs(data);
|
||||
|
||||
// We expect only a single image in the canvas output
|
||||
const imageDTO = imageDTOs[0];
|
||||
if (!imageDTO) {
|
||||
return;
|
||||
}
|
||||
@@ -127,20 +183,17 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
|
||||
if (data.result.type === 'image_output') {
|
||||
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
|
||||
}
|
||||
addImageToGallery(data, imageDTO);
|
||||
addImagesToGallery(data, [imageDTO]);
|
||||
}
|
||||
} else if (!imageDTO.is_intermediate) {
|
||||
// Desintaion is gallery
|
||||
addImageToGallery(data, imageDTO);
|
||||
addImagesToGallery(data, [imageDTO]);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {
|
||||
const imageDTO = await getResultImageDTO(data);
|
||||
|
||||
if (imageDTO && !imageDTO.is_intermediate) {
|
||||
addImageToGallery(data, imageDTO);
|
||||
}
|
||||
const imageDTOs = await getResultImageDTOs(data);
|
||||
addImagesToGallery(data, imageDTOs);
|
||||
};
|
||||
|
||||
return async (data: S['InvocationCompleteEvent']) => {
|
||||
|
||||
Reference in New Issue
Block a user