From e077fe8046cdb30a47124641cf24aca794c65033 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 10 Jan 2025 10:27:20 +1000 Subject: [PATCH] refactor(ui): streamline image field collection input logic, support multiple images w/ same name in collection --- invokeai/frontend/web/src/features/dnd/dnd.ts | 32 ++++++--- .../web/src/features/imageActions/actions.ts | 69 +++++++++++++------ .../ImageFieldCollectionInputComponent.tsx | 29 ++++---- 3 files changed, 85 insertions(+), 45 deletions(-) diff --git a/invokeai/frontend/web/src/features/dnd/dnd.ts b/invokeai/frontend/web/src/features/dnd/dnd.ts index 9b8972e3c4..026a2aa0db 100644 --- a/invokeai/frontend/web/src/features/dnd/dnd.ts +++ b/invokeai/frontend/web/src/features/dnd/dnd.ts @@ -1,3 +1,4 @@ +import { logger } from 'app/logging/logger'; import type { AppDispatch, RootState } from 'app/store/store'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import type { @@ -9,7 +10,6 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/ import type { BoardId } from 'features/gallery/store/types'; import { addImagesToBoard, - addImagesToNodeImageFieldCollectionAction, createNewCanvasEntityFromImage, removeImagesFromBoard, replaceCanvasEntityObjectsWithImage, @@ -19,10 +19,14 @@ import { setRegionalGuidanceReferenceImage, setUpscaleInitialImage, } from 'features/imageActions/actions'; -import type { FieldIdentifier } from 'features/nodes/types/field'; +import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice'; +import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors'; +import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field'; import type { ImageDTO } from 'services/api/types'; import type { JsonObject } from 'type-fest'; +const log = logger('dnd'); + type RecordUnknown = Record; type DndData< @@ -268,15 +272,27 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget< } const { fieldIdentifier } = targetData.payload; - const imageDTOs: ImageDTO[] = []; - if (singleImageDndSource.typeGuard(sourceData)) { - imageDTOs.push(sourceData.payload.imageDTO); - } else { - imageDTOs.push(...sourceData.payload.imageDTOs); + const fieldInputInstance = selectFieldInputInstance( + selectNodesSlice(getState()), + fieldIdentifier.nodeId, + fieldIdentifier.fieldName + ); + + if (!isImageFieldCollectionInputInstance(fieldInputInstance)) { + log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection'); + return; } - addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState }); + const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : []; + + if (singleImageDndSource.typeGuard(sourceData)) { + newValue.push({ image_name: sourceData.payload.imageDTO.image_name }); + } else { + newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name }))); + } + + dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue })); }, }; //#endregion diff --git a/invokeai/frontend/web/src/features/imageActions/actions.ts b/invokeai/frontend/web/src/features/imageActions/actions.ts index a6f09377c5..0cc49ce76e 100644 --- a/invokeai/frontend/web/src/features/imageActions/actions.ts +++ b/invokeai/frontend/web/src/features/imageActions/actions.ts @@ -31,12 +31,17 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions'; import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice'; import type { BoardId } from 'features/gallery/store/types'; -import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice'; +import { + fieldImageValueChanged, + fieldStringCollectionValueChanged, +} from 'features/nodes/store/nodesSlice'; import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors'; -import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field'; +import { + type FieldIdentifier, + isStringFieldCollectionInputInstance, +} from 'features/nodes/types/field'; import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice'; import { getOptimalDimension } from 'features/parameters/util/optimalDimension'; -import { uniqBy } from 'lodash-es'; import { imagesApi } from 'services/api/endpoints/images'; import type { ImageDTO } from 'services/api/types'; import type { Equals } from 'tsafe'; @@ -77,52 +82,74 @@ export const setNodeImageFieldImage = (arg: { dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO })); }; -export const addImagesToNodeImageFieldCollectionAction = (arg: { - imageDTOs: ImageDTO[]; +export const addStringToNodeStringFieldCollectionAction = (arg: { + value: string; fieldIdentifier: FieldIdentifier; dispatch: AppDispatch; getState: () => RootState; }) => { - const { imageDTOs, fieldIdentifier, dispatch, getState } = arg; + const { value, fieldIdentifier, dispatch, getState } = arg; const fieldInputInstance = selectFieldInputInstance( selectNodesSlice(getState()), fieldIdentifier.nodeId, fieldIdentifier.fieldName ); - if (!isImageFieldCollectionInputInstance(fieldInputInstance)) { - log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection'); + if (!isStringFieldCollectionInputInstance(fieldInputInstance)) { + log.warn({ fieldIdentifier }, 'Attempted to add strings to a non-string field collection'); return; } - const images = fieldInputInstance.value ? [...fieldInputInstance.value] : []; - images.push(...imageDTOs.map(({ image_name }) => ({ image_name }))); - const uniqueImages = uniqBy(images, 'image_name'); - dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages })); + const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : []; + fieldValue.push(value); + dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue })); }; -export const removeImageFromNodeImageFieldCollectionAction = (arg: { - imageName: string; +export const removeStringFromNodeStringFieldCollectionAction = (arg: { + index: number; fieldIdentifier: FieldIdentifier; dispatch: AppDispatch; getState: () => RootState; }) => { - const { imageName, fieldIdentifier, dispatch, getState } = arg; + const { index, fieldIdentifier, dispatch, getState } = arg; const fieldInputInstance = selectFieldInputInstance( selectNodesSlice(getState()), fieldIdentifier.nodeId, fieldIdentifier.fieldName ); - if (!isImageFieldCollectionInputInstance(fieldInputInstance)) { - log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection'); + if (!isStringFieldCollectionInputInstance(fieldInputInstance)) { + log.warn({ fieldIdentifier }, 'Attempted to remove string to a non-string field collection'); return; } - const images = fieldInputInstance.value ? [...fieldInputInstance.value] : []; - const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName); - const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name'); - dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages })); + const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : []; + fieldValue.splice(index, 1); + dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue })); +}; + +export const changeStringOnNodeStringFieldCollectionAction = (arg: { + index: number; + value: string; + fieldIdentifier: FieldIdentifier; + dispatch: AppDispatch; + getState: () => RootState; +}) => { + const { index, value, fieldIdentifier, dispatch, getState } = arg; + const fieldInputInstance = selectFieldInputInstance( + selectNodesSlice(getState()), + fieldIdentifier.nodeId, + fieldIdentifier.fieldName + ); + + if (!isStringFieldCollectionInputInstance(fieldInputInstance)) { + log.warn({ fieldIdentifier }, 'Attempted to add strings to a non-string field collection'); + return; + } + + const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : []; + fieldValue.splice(index, 1, value); + dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue })); }; export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent.tsx index aa8b537e92..d011a7ea76 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent.tsx @@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd'; import { DndDropTarget } from 'features/dnd/DndDropTarget'; import { DndImage } from 'features/dnd/DndImage'; import { DndImageIcon } from 'features/dnd/DndImageIcon'; -import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions'; import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid'; import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice'; +import type { ImageField } from 'features/nodes/types/common'; import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field'; import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; import { memo, useCallback, useMemo } from 'react'; @@ -61,15 +61,12 @@ export const ImageFieldCollectionInputComponent = memo( ); const onRemoveImage = useCallback( - (imageName: string) => { - removeImageFromNodeImageFieldCollectionAction({ - imageName, - fieldIdentifier: { nodeId, fieldName: field.name }, - dispatch: store.dispatch, - getState: store.getState, - }); + (index: number) => { + const newValue = field.value ? [...field.value] : []; + newValue.splice(index, 1); + store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue })); }, - [field.name, nodeId, store.dispatch, store.getState] + [field.name, field.value, nodeId, store] ); return ( @@ -102,9 +99,9 @@ export const ImageFieldCollectionInputComponent = memo( options={overlayscrollbarsOptions} > - {field.value.map(({ image_name }) => ( - - + {field.value.map((value, index) => ( + + ))} @@ -124,11 +121,11 @@ export const ImageFieldCollectionInputComponent = memo( ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent'; const ImageGridItemContent = memo( - ({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => { - const query = useGetImageDTOQuery(imageName); + ({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => { + const query = useGetImageDTOQuery(value.image_name); const onClickRemove = useCallback(() => { - onRemoveImage(imageName); - }, [imageName, onRemoveImage]); + onRemoveImage(index); + }, [index, onRemoveImage]); if (query.isLoading) { return ;