mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-06 21:55:11 -05:00
refactor(ui): streamline image field collection input logic, support multiple images w/ same name in collection
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type {
|
||||
@@ -9,7 +10,6 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import {
|
||||
addImagesToBoard,
|
||||
addImagesToNodeImageFieldCollectionAction,
|
||||
createNewCanvasEntityFromImage,
|
||||
removeImagesFromBoard,
|
||||
replaceCanvasEntityObjectsWithImage,
|
||||
@@ -19,10 +19,14 @@ import {
|
||||
setRegionalGuidanceReferenceImage,
|
||||
setUpscaleInitialImage,
|
||||
} from 'features/imageActions/actions';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('dnd');
|
||||
|
||||
type RecordUnknown = Record<string | symbol, unknown>;
|
||||
|
||||
type DndData<
|
||||
@@ -268,15 +272,27 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
|
||||
}
|
||||
|
||||
const { fieldIdentifier } = targetData.payload;
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
imageDTOs.push(sourceData.payload.imageDTO);
|
||||
} else {
|
||||
imageDTOs.push(...sourceData.payload.imageDTOs);
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
|
||||
const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
|
||||
} else {
|
||||
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
|
||||
}
|
||||
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
|
||||
},
|
||||
};
|
||||
//#endregion
|
||||
|
||||
@@ -31,12 +31,17 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
fieldImageValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import {
|
||||
type FieldIdentifier,
|
||||
isStringFieldCollectionInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { uniqBy } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
@@ -77,52 +82,74 @@ export const setNodeImageFieldImage = (arg: {
|
||||
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
|
||||
};
|
||||
|
||||
export const addImagesToNodeImageFieldCollectionAction = (arg: {
|
||||
imageDTOs: ImageDTO[];
|
||||
export const addStringToNodeStringFieldCollectionAction = (arg: {
|
||||
value: string;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
|
||||
const { value, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
|
||||
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add strings to a non-string field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
|
||||
const uniqueImages = uniqBy(images, 'image_name');
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
fieldValue.push(value);
|
||||
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
|
||||
};
|
||||
|
||||
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
|
||||
imageName: string;
|
||||
export const removeStringFromNodeStringFieldCollectionAction = (arg: {
|
||||
index: number;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { imageName, fieldIdentifier, dispatch, getState } = arg;
|
||||
const { index, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
|
||||
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to remove string to a non-string field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
|
||||
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
fieldValue.splice(index, 1);
|
||||
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
|
||||
};
|
||||
|
||||
export const changeStringOnNodeStringFieldCollectionAction = (arg: {
|
||||
index: number;
|
||||
value: string;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { index, value, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isStringFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add strings to a non-string field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const fieldValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
fieldValue.splice(index, 1, value);
|
||||
dispatch(fieldStringCollectionValueChanged({ ...fieldIdentifier, value: fieldValue }));
|
||||
};
|
||||
|
||||
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
|
||||
|
||||
@@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { DndImageIcon } from 'features/dnd/DndImageIcon';
|
||||
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -61,15 +61,12 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
);
|
||||
|
||||
const onRemoveImage = useCallback(
|
||||
(imageName: string) => {
|
||||
removeImageFromNodeImageFieldCollectionAction({
|
||||
imageName,
|
||||
fieldIdentifier: { nodeId, fieldName: field.name },
|
||||
dispatch: store.dispatch,
|
||||
getState: store.getState,
|
||||
});
|
||||
(index: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue.splice(index, 1);
|
||||
store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, nodeId, store.dispatch, store.getState]
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -102,9 +99,9 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid w="full" h="full" templateColumns="repeat(4, 1fr)" gap={1}>
|
||||
{field.value.map(({ image_name }) => (
|
||||
<GridItem key={image_name} position="relative" className="nodrag">
|
||||
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
|
||||
{field.value.map((value, index) => (
|
||||
<GridItem key={index} position="relative" className="nodrag">
|
||||
<ImageGridItemContent value={value} index={index} onRemoveImage={onRemoveImage} />
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
@@ -124,11 +121,11 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
|
||||
|
||||
const ImageGridItemContent = memo(
|
||||
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
|
||||
const query = useGetImageDTOQuery(imageName);
|
||||
({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => {
|
||||
const query = useGetImageDTOQuery(value.image_name);
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveImage(imageName);
|
||||
}, [imageName, onRemoveImage]);
|
||||
onRemoveImage(index);
|
||||
}, [index, onRemoveImage]);
|
||||
|
||||
if (query.isLoading) {
|
||||
return <IAINoContentFallbackWithSpinner />;
|
||||
|
||||
Reference in New Issue
Block a user