feat(ui): add _all_ image outputs to gallery (including collections)

This commit is contained in:
psychedelicious
2025-04-25 14:22:40 +10:00
parent c768f47365
commit 3988128c40
2 changed files with 103 additions and 45 deletions

View File

@@ -8,6 +8,11 @@ export const zImageField = z.object({
image_name: z.string().trim().min(1),
});
export type ImageField = z.infer<typeof zImageField>;
export const isImageField = (field: unknown): field is ImageField => zImageField.safeParse(field).success;
const zImageFieldCollection = z.array(zImageField);
type ImageFieldCollection = z.infer<typeof zImageFieldCollection>;
export const isImageFieldCollection = (field: unknown): field is ImageFieldCollection =>
zImageFieldCollection.safeParse(field).success;
export const zBoardField = z.object({
board_id: z.string().trim().min(1),

View File

@@ -4,13 +4,17 @@ import { deepClone } from 'common/util/deepClone';
import { stagingAreaImageStaged } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { boardIdSelected, galleryViewChanged, imageSelected, offsetChanged } from 'features/gallery/store/gallerySlice';
import { $nodeExecutionStates, upsertExecutionState } from 'features/nodes/hooks/useNodeExecutionState';
import { isImageField, isImageFieldCollection } from 'features/nodes/types/common';
import { zNodeStatus } from 'features/nodes/types/invocation';
import { CANVAS_OUTPUT_PREFIX } from 'features/nodes/util/graph/graphBuilderUtils';
import type { ApiTagDescription } from 'services/api';
import { boardsApi } from 'services/api/endpoints/boards';
import { getImageDTOSafe, imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO, S } from 'services/api/types';
import { getCategories, getListImagesUrl } from 'services/api/util';
import { $lastProgressEvent } from 'services/events/stores';
import type { Param0 } from 'tsafe';
import { objectEntries } from 'tsafe';
import type { JsonObject } from 'type-fest';
const log = logger('events');
@@ -22,58 +26,98 @@ const isCanvasOutputNode = (data: S['InvocationCompleteEvent']) => {
const nodeTypeDenylist = ['load_image', 'image'];
export const buildOnInvocationComplete = (getState: () => RootState, dispatch: AppDispatch) => {
const addImageToGallery = (data: S['InvocationCompleteEvent'], imageDTO: ImageDTO) => {
const addImagesToGallery = (data: S['InvocationCompleteEvent'], imageDTOs: ImageDTO[]) => {
if (nodeTypeDenylist.includes(data.invocation.type)) {
log.trace('Skipping node type denylisted');
log.trace(`Skipping denylisted node type (${data.invocation.type})`);
return;
}
if (imageDTO.is_intermediate) {
// For efficiency's sake, we want to minimize the number of dispatches and invalidations we do.
// We'll keep track of each change we need to make and do them all at once.
const boardTotalAdditions: Record<string, number> = {};
const boardTagIdsToInvalidate: Set<string> = new Set();
const imageListTagIdsToInvalidate: Set<string> = new Set();
for (const imageDTO of imageDTOs) {
if (imageDTO.is_intermediate) {
return;
}
const boardId = imageDTO.board_id ?? 'none';
// update the total images for the board
boardTotalAdditions[boardId] = (boardTotalAdditions[boardId] || 0) + 1;
// invalidate the board tag
boardTagIdsToInvalidate.add(boardId);
// invalidate the image list tag
imageListTagIdsToInvalidate.add(
getListImagesUrl({
board_id: boardId,
categories: getCategories(imageDTO),
})
);
}
// Update all the board image totals at once
const entries: Param0<typeof boardsApi.util.upsertQueryEntries> = [];
for (const [boardId, amountToAdd] of objectEntries(boardTotalAdditions)) {
// upsertQueryEntries doesn't provide a "recipe" function for the update - we must provide the new value
// directly. So we need to select the board totals first.
const total = boardsApi.endpoints.getBoardImagesTotal.select(boardId)(getState()).data?.total;
if (total === undefined) {
// No cache exists for this board, so we can't update it.
continue;
}
entries.push({
endpointName: 'getBoardImagesTotal',
arg: boardId,
value: { total: total + amountToAdd },
});
}
dispatch(boardsApi.util.upsertQueryEntries(entries));
// Invalidate all tags at once
const boardTags: ApiTagDescription[] = Array.from(boardTagIdsToInvalidate).map((boardId) => ({
type: 'Board' as const,
id: boardId,
}));
const imageListTags: ApiTagDescription[] = Array.from(imageListTagIdsToInvalidate).map((imageListId) => ({
type: 'ImageList' as const,
id: imageListId,
}));
dispatch(imagesApi.util.invalidateTags([...boardTags, ...imageListTags]));
// Finally, we may need to autoswitch to the new image. We'll only do it for the last image in the list.
const lastImageDTO = imageDTOs.at(-1);
if (!lastImageDTO) {
return;
}
// update the total images for the board
dispatch(
boardsApi.util.updateQueryData('getBoardImagesTotal', imageDTO.board_id ?? 'none', (draft) => {
draft.total += 1;
})
);
dispatch(
imagesApi.util.invalidateTags([
{ type: 'Board', id: imageDTO.board_id ?? 'none' },
{
type: 'ImageList',
id: getListImagesUrl({
board_id: imageDTO.board_id ?? 'none',
categories: getCategories(imageDTO),
}),
},
])
);
const { image_name, board_id } = lastImageDTO;
const { shouldAutoSwitch, selectedBoardId, galleryView, offset } = getState().gallery;
// If auto-switch is enabled, select the new image
if (shouldAutoSwitch) {
// If the image is from a different board, switch to that board - this will also select the image
if (imageDTO.board_id && imageDTO.board_id !== selectedBoardId) {
if (board_id && board_id !== selectedBoardId) {
dispatch(
boardIdSelected({
boardId: imageDTO.board_id,
selectedImageName: imageDTO.image_name,
boardId: board_id,
selectedImageName: image_name,
})
);
} else if (!imageDTO.board_id && selectedBoardId !== 'none') {
} else if (!board_id && selectedBoardId !== 'none') {
dispatch(
boardIdSelected({
boardId: 'none',
selectedImageName: imageDTO.image_name,
selectedImageName: image_name,
})
);
} else {
// Else just select the image, no need to switch boards
dispatch(imageSelected(imageDTO));
dispatch(imageSelected(lastImageDTO));
if (galleryView !== 'images') {
// We also need to update the gallery view to images. This also updates the offset.
@@ -86,12 +130,25 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
}
};
const getResultImageDTO = (data: S['InvocationCompleteEvent']) => {
const getResultImageDTOs = async (data: S['InvocationCompleteEvent']): Promise<ImageDTO[]> => {
const { result } = data;
if (result.type === 'image_output') {
return getImageDTOSafe(result.image.image_name);
const imageDTOs: ImageDTO[] = [];
for (const [_name, value] of objectEntries(result)) {
if (isImageField(value)) {
const imageDTO = await getImageDTOSafe(value.image_name);
if (imageDTO) {
imageDTOs.push(imageDTO);
}
} else if (isImageFieldCollection(value)) {
for (const imageField of value) {
const imageDTO = await getImageDTOSafe(imageField.image_name);
if (imageDTO) {
imageDTOs.push(imageDTO);
}
}
}
}
return null;
return imageDTOs;
};
const handleOriginWorkflows = async (data: S['InvocationCompleteEvent']) => {
@@ -107,16 +164,15 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
upsertExecutionState(nes.nodeId, nes);
}
const imageDTO = await getResultImageDTO(data);
if (imageDTO && !imageDTO.is_intermediate) {
addImageToGallery(data, imageDTO);
}
const imageDTOs = await getResultImageDTOs(data);
addImagesToGallery(data, imageDTOs);
};
const handleOriginCanvas = async (data: S['InvocationCompleteEvent']) => {
const imageDTO = await getResultImageDTO(data);
const imageDTOs = await getResultImageDTOs(data);
// We expect only a single image in the canvas output
const imageDTO = imageDTOs[0];
if (!imageDTO) {
return;
}
@@ -127,20 +183,17 @@ export const buildOnInvocationComplete = (getState: () => RootState, dispatch: A
if (data.result.type === 'image_output') {
dispatch(stagingAreaImageStaged({ stagingAreaImage: { imageDTO, offsetX: 0, offsetY: 0 } }));
}
addImageToGallery(data, imageDTO);
addImagesToGallery(data, [imageDTO]);
}
} else if (!imageDTO.is_intermediate) {
// Desintaion is gallery
addImageToGallery(data, imageDTO);
addImagesToGallery(data, [imageDTO]);
}
};
const handleOriginOther = async (data: S['InvocationCompleteEvent']) => {
const imageDTO = await getResultImageDTO(data);
if (imageDTO && !imageDTO.is_intermediate) {
addImageToGallery(data, imageDTO);
}
const imageDTOs = await getResultImageDTOs(data);
addImagesToGallery(data, imageDTOs);
};
return async (data: S['InvocationCompleteEvent']) => {