mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): image batching in workflows
- Add special handling for `ImageBatchInvocation` - Add input component for image collections, supporting multi-image upload and dnd - Minor rework of some hooks for accessing node data
This commit is contained in:
@@ -1,10 +1,15 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig } from 'services/api/types';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
const log = logger('workflows');
|
||||
|
||||
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -26,6 +31,33 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
delete builtWorkflow.id;
|
||||
}
|
||||
|
||||
const data: Batch['data'] = [];
|
||||
|
||||
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
|
||||
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
|
||||
for (const node of imageBatchNodes) {
|
||||
const images = node.data.inputs['images'];
|
||||
if (!isImageFieldCollectionInputInstance(images)) {
|
||||
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
|
||||
break;
|
||||
}
|
||||
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
|
||||
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edgesFromImageBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
batchDataCollectionItem.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: images.value,
|
||||
});
|
||||
}
|
||||
if (batchDataCollectionItem.length > 0) {
|
||||
data.push(batchDataCollectionItem);
|
||||
}
|
||||
}
|
||||
|
||||
const batchConfig: BatchConfig = {
|
||||
batch: {
|
||||
graph,
|
||||
@@ -33,6 +65,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
runs: state.params.iterations,
|
||||
origin: 'workflows',
|
||||
destination: 'gallery',
|
||||
data,
|
||||
},
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
@@ -141,11 +141,9 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
|
||||
};
|
||||
|
||||
const sx = {
|
||||
borderColor: 'error.500',
|
||||
borderStyle: 'solid',
|
||||
borderWidth: 0,
|
||||
borderRadius: 'base',
|
||||
'&[data-error=true]': {
|
||||
borderColor: 'error.500',
|
||||
borderStyle: 'solid',
|
||||
borderWidth: 1,
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
@@ -164,7 +162,34 @@ export const UploadImageButton = ({
|
||||
<>
|
||||
<IconButton
|
||||
aria-label="Upload image"
|
||||
variant="ghost"
|
||||
variant="outline"
|
||||
sx={sx}
|
||||
data-error={isError}
|
||||
icon={<PiUploadBold />}
|
||||
isLoading={uploadApi.request.isLoading}
|
||||
{...rest}
|
||||
{...uploadApi.getUploadButtonProps()}
|
||||
/>
|
||||
<input {...uploadApi.getUploadInputProps()} />
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export const UploadMultipleImageButton = ({
|
||||
isDisabled = false,
|
||||
onUpload,
|
||||
isError = false,
|
||||
...rest
|
||||
}: {
|
||||
onUpload?: (imageDTOs: ImageDTO[]) => void;
|
||||
isError?: boolean;
|
||||
} & SetOptional<IconButtonProps, 'aria-label'>) => {
|
||||
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: true, onUpload });
|
||||
return (
|
||||
<>
|
||||
<IconButton
|
||||
aria-label="Upload image"
|
||||
variant="outline"
|
||||
sx={sx}
|
||||
data-error={isError}
|
||||
icon={<PiUploadBold />}
|
||||
|
||||
@@ -22,12 +22,15 @@ const sx = {
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
type Props = ImageProps & {
|
||||
imageDTO: ImageDTO;
|
||||
asThumbnail?: boolean;
|
||||
};
|
||||
/* eslint-disable-next-line @typescript-eslint/no-namespace */
|
||||
export namespace DndImage {
|
||||
export interface Props extends ImageProps {
|
||||
imageDTO: ImageDTO;
|
||||
asThumbnail?: boolean;
|
||||
}
|
||||
}
|
||||
|
||||
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
|
||||
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: DndImage.Props) => {
|
||||
const store = useAppStore();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [element, ref] = useState<HTMLImageElement | null>(null);
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { memo } from 'react';
|
||||
import { PiExclamationMarkBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-namespace */
|
||||
namespace DndImageFromImageName {
|
||||
export interface Props extends Omit<DndImage.Props, 'imageDTO'> {
|
||||
imageName: string;
|
||||
}
|
||||
}
|
||||
|
||||
export const DndImageFromImageName = memo(({ imageName, ...rest }: DndImageFromImageName.Props) => {
|
||||
const query = useGetImageDTOQuery(imageName);
|
||||
if (query.isLoading) {
|
||||
return <IAINoContentFallbackWithSpinner />;
|
||||
}
|
||||
if (!query.data) {
|
||||
return <IAINoContentFallback icon={<PiExclamationMarkBold />} />;
|
||||
}
|
||||
|
||||
return <DndImage imageDTO={query.data} {...rest} />;
|
||||
});
|
||||
|
||||
DndImageFromImageName.displayName = 'DndImageFromImageName';
|
||||
@@ -9,6 +9,7 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import {
|
||||
addImagesToBoard,
|
||||
addImagesToNodeImageFieldCollectionAction,
|
||||
createNewCanvasEntityFromImage,
|
||||
removeImagesFromBoard,
|
||||
replaceCanvasEntityObjectsWithImage,
|
||||
@@ -241,6 +242,45 @@ export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDn
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region Add Images to Image Collection Node Field
|
||||
const _addImagesToNodeImageFieldCollection = buildTypeAndKey('add-images-to-image-collection-node-field');
|
||||
export type AddImagesToNodeImageFieldCollection = DndData<
|
||||
typeof _addImagesToNodeImageFieldCollection.type,
|
||||
typeof _addImagesToNodeImageFieldCollection.key,
|
||||
{ fieldIdentifer: FieldIdentifier }
|
||||
>;
|
||||
export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
|
||||
AddImagesToNodeImageFieldCollection,
|
||||
SingleImageDndSourceData | MultipleImageDndSourceData
|
||||
> = {
|
||||
..._addImagesToNodeImageFieldCollection,
|
||||
typeGuard: buildTypeGuard(_addImagesToNodeImageFieldCollection.key),
|
||||
getData: buildGetData(_addImagesToNodeImageFieldCollection.key, _addImagesToNodeImageFieldCollection.type),
|
||||
isValid: ({ sourceData }) => {
|
||||
if (singleImageDndSource.typeGuard(sourceData) || multipleImageDndSource.typeGuard(sourceData)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
handler: ({ sourceData, targetData, dispatch }) => {
|
||||
if (!singleImageDndSource.typeGuard(sourceData) && !multipleImageDndSource.typeGuard(sourceData)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { fieldIdentifer } = targetData.payload;
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
imageDTOs.push(sourceData.payload.imageDTO);
|
||||
} else {
|
||||
imageDTOs.push(...sourceData.payload.imageDTOs);
|
||||
}
|
||||
|
||||
addImagesToNodeImageFieldCollectionAction({ fieldIdentifer, imageDTOs, dispatch });
|
||||
},
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//# Set Comparison Image
|
||||
const _setComparisonImage = buildTypeAndKey('set-comparison-image');
|
||||
export type SetComparisonImageDndTargetData = DndData<
|
||||
@@ -430,6 +470,7 @@ export const dndTargets = [
|
||||
// Single or Multiple Image
|
||||
addImageToBoardDndTarget,
|
||||
removeImageFromBoardDndTarget,
|
||||
addImagesToNodeImageFieldCollectionDndTarget,
|
||||
] as const;
|
||||
|
||||
export type AnyDndTarget = (typeof dndTargets)[number];
|
||||
|
||||
@@ -29,7 +29,7 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
@@ -71,6 +71,15 @@ export const setNodeImageFieldImage = (arg: {
|
||||
dispatch(fieldImageValueChanged({ ...fieldIdentifer, value: imageDTO }));
|
||||
};
|
||||
|
||||
export const addImagesToNodeImageFieldCollectionAction = (arg: {
|
||||
imageDTOs: ImageDTO[];
|
||||
fieldIdentifer: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
}) => {
|
||||
const { imageDTOs, fieldIdentifer, dispatch } = arg;
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifer, value: imageDTOs }));
|
||||
};
|
||||
|
||||
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
|
||||
const { imageDTO, dispatch } = arg;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
|
||||
@@ -40,7 +40,7 @@ import { computed } from 'nanostores';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCircuitryBold, PiFlaskBold, PiHammerBold } from 'react-icons/pi';
|
||||
import { PiCircuitryBold, PiFlaskBold, PiHammerBold, PiLightningFill } from 'react-icons/pi';
|
||||
import type { EdgeChange, NodeChange } from 'reactflow';
|
||||
import type { S } from 'services/api/types';
|
||||
|
||||
@@ -403,7 +403,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
|
||||
}
|
||||
|
||||
return _items;
|
||||
}, [pendingConnection, currentImageFilterItem, searchTerm, notesFilterItem, templatesArray]);
|
||||
}, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem]);
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -414,6 +414,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
|
||||
{item.classification === 'beta' && <Icon boxSize={4} color="invokeYellow.300" as={PiHammerBold} />}
|
||||
{item.classification === 'prototype' && <Icon boxSize={4} color="invokeRed.300" as={PiFlaskBold} />}
|
||||
{item.classification === 'internal' && <Icon boxSize={4} color="invokePurple.300" as={PiCircuitryBold} />}
|
||||
{item.classification === 'special' && <Icon boxSize={4} color="invokeGreen.300" as={PiLightningFill} />}
|
||||
<Text fontWeight="semibold">{item.label}</Text>
|
||||
<Spacer />
|
||||
<Text variant="subtext" fontWeight="semibold">
|
||||
|
||||
@@ -3,7 +3,7 @@ import { useNodeClassification } from 'features/nodes/hooks/useNodeClassificatio
|
||||
import type { Classification } from 'features/nodes/types/common';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiCircuitryBold, PiFlaskBold, PiHammerBold } from 'react-icons/pi';
|
||||
import { PiCircuitryBold, PiFlaskBold, PiHammerBold, PiLightningFill } from 'react-icons/pi';
|
||||
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
@@ -62,5 +62,9 @@ const ClassificationIcon = ({ classification }: { classification: Classification
|
||||
return <Icon as={PiCircuitryBold} display="block" boxSize={4} color="invokePurple.300" />;
|
||||
}
|
||||
|
||||
if (classification === 'special') {
|
||||
return <Icon as={PiLightningFill} display="block" boxSize={4} color="invokeGreen.300" />;
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNode } from 'features/nodes/hooks/useNode';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiInfoBold } from 'react-icons/pi';
|
||||
@@ -25,32 +25,32 @@ const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
|
||||
export default memo(InvocationNodeInfoIcon);
|
||||
|
||||
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const data = useNodeData(nodeId);
|
||||
const node = useNode(nodeId);
|
||||
const nodeTemplate = useNodeTemplate(nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const title = useMemo(() => {
|
||||
if (data?.label && nodeTemplate?.title) {
|
||||
return `${data.label} (${nodeTemplate.title})`;
|
||||
if (node.data?.label && nodeTemplate?.title) {
|
||||
return `${node.data.label} (${nodeTemplate.title})`;
|
||||
}
|
||||
|
||||
if (data?.label && !nodeTemplate) {
|
||||
return data.label;
|
||||
if (node.data?.label && !nodeTemplate) {
|
||||
return node.data.label;
|
||||
}
|
||||
|
||||
if (!data?.label && nodeTemplate) {
|
||||
if (!node.data?.label && nodeTemplate) {
|
||||
return nodeTemplate.title;
|
||||
}
|
||||
|
||||
return t('nodes.unknownNode');
|
||||
}, [data, nodeTemplate, t]);
|
||||
}, [node.data.label, nodeTemplate, t]);
|
||||
|
||||
const versionComponent = useMemo(() => {
|
||||
if (!isInvocationNodeData(data) || !nodeTemplate) {
|
||||
if (!isInvocationNode(node) || !nodeTemplate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!data.version) {
|
||||
if (!node.data.version) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.versionUnknown')}
|
||||
@@ -61,35 +61,35 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
if (!nodeTemplate.version) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.unknownTemplate')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.unknownTemplate')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '<')) {
|
||||
if (compare(node.data.version, nodeTemplate.version, '<')) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.updateNode')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.updateNode')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '>')) {
|
||||
if (compare(node.data.version, nodeTemplate.version, '>')) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.updateApp')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.updateApp')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Text as="span">
|
||||
{t('nodes.version')} {data.version}
|
||||
{t('nodes.version')} {node.data.version}
|
||||
</Text>
|
||||
);
|
||||
}, [data, nodeTemplate, t]);
|
||||
}, [node, nodeTemplate, t]);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
if (!isInvocationNode(node)) {
|
||||
return <Text fontWeight="semibold">{t('nodes.unknownNode')}</Text>;
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
{nodeTemplate?.description}
|
||||
</Text>
|
||||
{versionComponent}
|
||||
{data?.notes && <Text>{data.notes}</Text>}
|
||||
{node.data?.notes && <Text>{node.data.notes}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import { FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNode } from 'features/nodes/hooks/useNode';
|
||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const data = useNodeData(nodeId);
|
||||
const node = useNode(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const handleNotesChanged = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
@@ -17,13 +17,13 @@ const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
if (!isInvocationNodeData(data)) {
|
||||
if (!isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<FormControl orientation="vertical" h="full">
|
||||
<FormLabel>{t('nodes.notes')}</FormLabel>
|
||||
<Textarea value={data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
|
||||
<Textarea value={node.data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
@@ -24,6 +25,8 @@ import {
|
||||
isFluxMainModelFieldInputTemplate,
|
||||
isFluxVAEModelFieldInputInstance,
|
||||
isFluxVAEModelFieldInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
@@ -110,6 +113,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <ImageFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
|
||||
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,99 @@
|
||||
import { Flex, Grid, GridItem, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { UploadMultipleImageButton } from 'common/hooks/useImageUploadButton';
|
||||
import type { AddImagesToNodeImageFieldCollection } from 'features/dnd/dnd';
|
||||
import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImageFromImageName } from 'features/dnd/DndImageFromImageName';
|
||||
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
export const ImageFieldCollectionInputComponent = memo(
|
||||
(props: FieldComponentProps<ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate>) => {
|
||||
const { t } = useTranslation();
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const dispatch = useAppDispatch();
|
||||
const onReset = useCallback(() => {
|
||||
dispatch(
|
||||
fieldImageCollectionValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: [],
|
||||
})
|
||||
);
|
||||
}, [dispatch, field.name, nodeId]);
|
||||
|
||||
const dndTargetData = useMemo<AddImagesToNodeImageFieldCollection>(
|
||||
() => addImagesToNodeImageFieldCollectionDndTarget.getData({ fieldIdentifer: { nodeId, fieldName: field.name } }),
|
||||
[field, nodeId]
|
||||
);
|
||||
|
||||
const onUpload = useCallback(
|
||||
(imageDTOs: ImageDTO[]) => {
|
||||
dispatch(
|
||||
fieldImageCollectionValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value: imageDTOs,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="relative"
|
||||
className="nodrag"
|
||||
w="full"
|
||||
h="full"
|
||||
minH={16}
|
||||
alignItems="stretch"
|
||||
justifyContent="center"
|
||||
>
|
||||
{field.value.length === 0 && (
|
||||
<UploadMultipleImageButton
|
||||
w="full"
|
||||
h="auto"
|
||||
isError={fieldTemplate.required}
|
||||
onUpload={onUpload}
|
||||
fontSize={24}
|
||||
variant="outline"
|
||||
/>
|
||||
)}
|
||||
{field.value.length > 0 && (
|
||||
<>
|
||||
<Grid className="nopan" w="full" h="full" templateColumns="repeat(3, 1fr)" gap={2}>
|
||||
{field.value.map(({ image_name }) => (
|
||||
<GridItem key={image_name}>
|
||||
<DndImageFromImageName imageName={image_name} asThumbnail />
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
<IconButton
|
||||
aria-label="reset"
|
||||
icon={<PiArrowCounterClockwiseBold />}
|
||||
position="absolute"
|
||||
top={0}
|
||||
insetInlineEnd={0}
|
||||
onClick={onReset}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
<DndDropTarget
|
||||
dndTarget={addImagesToNodeImageFieldCollectionDndTarget}
|
||||
dndTargetData={dndTargetData}
|
||||
label={t('gallery.drop')}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
|
||||
19
invokeai/frontend/web/src/features/nodes/hooks/useNode.ts
Normal file
19
invokeai/frontend/web/src/features/nodes/hooks/useNode.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNode, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
import type { Node } from 'reactflow';
|
||||
|
||||
export const useNode = (nodeId: string): Node => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNode(nodes, nodeId);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const node = useAppSelector(selector);
|
||||
|
||||
return node;
|
||||
};
|
||||
@@ -16,6 +16,7 @@ import type {
|
||||
FieldValue,
|
||||
FloatFieldValue,
|
||||
FluxVAEModelFieldValue,
|
||||
ImageFieldCollectionValue,
|
||||
ImageFieldValue,
|
||||
IntegerFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
@@ -42,6 +43,7 @@ import {
|
||||
zEnumFieldValue,
|
||||
zFloatFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zImageFieldCollectionValue,
|
||||
zImageFieldValue,
|
||||
zIntegerFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
@@ -319,6 +321,9 @@ export const nodesSlice = createSlice({
|
||||
fieldImageValueChanged: (state, action: FieldValueAction<ImageFieldValue>) => {
|
||||
fieldValueReducer(state, action, zImageFieldValue);
|
||||
},
|
||||
fieldImageCollectionValueChanged: (state, action: FieldValueAction<ImageFieldCollectionValue>) => {
|
||||
fieldValueReducer(state, action, zImageFieldCollectionValue);
|
||||
},
|
||||
fieldColorValueChanged: (state, action: FieldValueAction<ColorFieldValue>) => {
|
||||
fieldValueReducer(state, action, zColorFieldValue);
|
||||
},
|
||||
@@ -416,6 +421,7 @@ export const {
|
||||
fieldControlNetModelValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldImageCollectionValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldSpandrelImageToImageModelValueChanged,
|
||||
@@ -527,6 +533,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldControlNetModelValueChanged,
|
||||
fieldEnumModelValueChanged,
|
||||
fieldImageValueChanged,
|
||||
fieldImageCollectionValueChanged,
|
||||
fieldIPAdapterModelValueChanged,
|
||||
fieldT2IAdapterModelValueChanged,
|
||||
fieldLabelChanged,
|
||||
|
||||
@@ -5,8 +5,15 @@ import type { NodesState } from 'features/nodes/store/types';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { Node } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const selectNode = (nodesSlice: NodesState, nodeId: string): Node => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(node !== undefined, `Node ${nodeId} not found`);
|
||||
return node;
|
||||
};
|
||||
|
||||
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
|
||||
|
||||
@@ -22,7 +22,7 @@ export const zColorField = z.object({
|
||||
});
|
||||
export type ColorField = z.infer<typeof zColorField>;
|
||||
|
||||
export const zClassification = z.enum(['stable', 'beta', 'prototype', 'deprecated', 'internal']);
|
||||
export const zClassification = z.enum(['stable', 'beta', 'prototype', 'deprecated', 'internal', 'special']);
|
||||
export type Classification = z.infer<typeof zClassification>;
|
||||
|
||||
export const zSchedulerField = z.enum([
|
||||
|
||||
@@ -59,6 +59,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
EnumField: 'blue.500',
|
||||
FloatField: 'orange.500',
|
||||
ImageField: 'purple.500',
|
||||
ImageBatchField: 'purple.500',
|
||||
IntegerField: 'red.500',
|
||||
IPAdapterField: 'teal.500',
|
||||
IPAdapterModelField: 'teal.500',
|
||||
|
||||
@@ -95,6 +95,11 @@ const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zImageCollectionFieldType = z.object({
|
||||
name: z.literal('ImageField'),
|
||||
cardinality: z.literal(zCardinality.Values.COLLECTION),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -347,7 +352,6 @@ export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTem
|
||||
// #endregion
|
||||
|
||||
// #region ImageField
|
||||
|
||||
export const zImageFieldValue = zImageField.optional();
|
||||
const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zImageFieldValue,
|
||||
@@ -369,6 +373,28 @@ export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputT
|
||||
zImageFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region ImageField Collection
|
||||
export const zImageFieldCollectionValue = z.array(zImageField);
|
||||
const zImageFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zImageFieldCollectionValue,
|
||||
});
|
||||
const zImageFieldCollectionInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zImageCollectionFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zImageFieldCollectionValue,
|
||||
});
|
||||
const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zImageCollectionFieldType,
|
||||
});
|
||||
export type ImageFieldCollectionValue = z.infer<typeof zImageFieldCollectionValue>;
|
||||
export type ImageFieldCollectionInputInstance = z.infer<typeof zImageFieldCollectionInputInstance>;
|
||||
export type ImageFieldCollectionInputTemplate = z.infer<typeof zImageFieldCollectionInputTemplate>;
|
||||
export const isImageFieldCollectionInputInstance = (val: unknown): val is ImageFieldCollectionInputInstance =>
|
||||
zImageFieldCollectionInputInstance.safeParse(val).success;
|
||||
export const isImageFieldCollectionInputTemplate = (val: unknown): val is ImageFieldCollectionInputTemplate =>
|
||||
zImageFieldCollectionInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region BoardField
|
||||
|
||||
export const zBoardFieldValue = zBoardField.optional();
|
||||
@@ -885,6 +911,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zBooleanFieldValue,
|
||||
zEnumFieldValue,
|
||||
zImageFieldValue,
|
||||
zImageFieldCollectionValue,
|
||||
zBoardFieldValue,
|
||||
zModelIdentifierFieldValue,
|
||||
zMainModelFieldValue,
|
||||
@@ -920,6 +947,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zBooleanFieldInputInstance,
|
||||
zEnumFieldInputInstance,
|
||||
zImageFieldInputInstance,
|
||||
zImageFieldCollectionInputInstance,
|
||||
zBoardFieldInputInstance,
|
||||
zModelIdentifierFieldInputInstance,
|
||||
zMainModelFieldInputInstance,
|
||||
@@ -954,6 +982,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zBooleanFieldInputTemplate,
|
||||
zEnumFieldInputTemplate,
|
||||
zImageFieldInputTemplate,
|
||||
zImageFieldCollectionInputTemplate,
|
||||
zBoardFieldInputTemplate,
|
||||
zModelIdentifierFieldInputTemplate,
|
||||
zMainModelFieldInputTemplate,
|
||||
@@ -991,6 +1020,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zBooleanFieldOutputTemplate,
|
||||
zEnumFieldOutputTemplate,
|
||||
zImageFieldOutputTemplate,
|
||||
zImageFieldCollectionOutputTemplate,
|
||||
zBoardFieldOutputTemplate,
|
||||
zModelIdentifierFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
|
||||
@@ -64,8 +64,6 @@ export type AnyNode = Node<AnyNodeData>;
|
||||
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
|
||||
// #endregion
|
||||
|
||||
// #region NodeExecutionState
|
||||
|
||||
@@ -1,16 +1,20 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { omit, reduce } from 'lodash-es';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const log = logger('workflows');
|
||||
|
||||
/**
|
||||
* Builds a graph from the node editor state.
|
||||
*/
|
||||
export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
const { nodes, edges } = nodesState;
|
||||
|
||||
const filteredNodes = nodes.filter(isInvocationNode);
|
||||
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
|
||||
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');
|
||||
|
||||
// Reduce the node editor nodes into invocation graph nodes
|
||||
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
|
||||
@@ -47,22 +51,31 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
return nodesAccumulator;
|
||||
}, {});
|
||||
|
||||
const filteredNodeIds = filteredNodes.map(({ id }) => id);
|
||||
|
||||
// skip out the "dummy" edges between collapsed nodes
|
||||
const filteredEdges = edges.filter((n) => n.type !== 'collapsed');
|
||||
const filteredEdges = edges
|
||||
.filter((edge) => edge.type !== 'collapsed')
|
||||
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
|
||||
|
||||
// Reduce the node editor edges into invocation graph edges
|
||||
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
|
||||
const { source, target, sourceHandle, targetHandle } = edge;
|
||||
|
||||
if (!sourceHandle || !targetHandle) {
|
||||
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
|
||||
return edgesAccumulator;
|
||||
}
|
||||
|
||||
// Format the edges and add to the edges array
|
||||
edgesAccumulator.push({
|
||||
source: {
|
||||
node_id: source,
|
||||
field: sourceHandle as string,
|
||||
field: sourceHandle,
|
||||
},
|
||||
destination: {
|
||||
node_id: target,
|
||||
field: targetHandle as string,
|
||||
field: targetHandle,
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ const QueueItemComponent = ({ index, item, context }: InnerItemProps) => {
|
||||
<Text as="span" fontWeight="semibold">
|
||||
{node_path}.{field_name}
|
||||
</Text>
|
||||
: {value}
|
||||
: {JSON.stringify(value)}
|
||||
</Text>
|
||||
))}
|
||||
</Flex>
|
||||
|
||||
Reference in New Issue
Block a user