feat(ui): image batching in workflows

- Add special handling for `ImageBatchInvocation`
- Add input component for image collections, supporting multi-image upload and dnd
- Minor rework of some hooks for accessing node data
This commit is contained in:
psychedelicious
2024-11-15 19:27:49 -08:00
parent e1626a4e49
commit 616c0f11e1
21 changed files with 372 additions and 49 deletions

View File

@@ -1,10 +1,15 @@
import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import type { Batch, BatchConfig } from 'services/api/types';
const log = logger('workflows');
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
startAppListening({
@@ -26,6 +31,33 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
delete builtWorkflow.id;
}
const data: Batch['data'] = [];
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromImageBatch) {
if (!edge.targetHandle) {
break;
}
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: images.value,
});
}
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
}
}
const batchConfig: BatchConfig = {
batch: {
graph,
@@ -33,6 +65,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend: action.payload.prepend,
};

View File

@@ -141,11 +141,9 @@ export const useImageUploadButton = ({ onUpload, isDisabled, allowMultiple }: Us
};
const sx = {
borderColor: 'error.500',
borderStyle: 'solid',
borderWidth: 0,
borderRadius: 'base',
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
borderWidth: 1,
},
} satisfies SystemStyleObject;
@@ -164,7 +162,34 @@ export const UploadImageButton = ({
<>
<IconButton
aria-label="Upload image"
variant="ghost"
variant="outline"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}
isLoading={uploadApi.request.isLoading}
{...rest}
{...uploadApi.getUploadButtonProps()}
/>
<input {...uploadApi.getUploadInputProps()} />
</>
);
};
export const UploadMultipleImageButton = ({
isDisabled = false,
onUpload,
isError = false,
...rest
}: {
onUpload?: (imageDTOs: ImageDTO[]) => void;
isError?: boolean;
} & SetOptional<IconButtonProps, 'aria-label'>) => {
const uploadApi = useImageUploadButton({ isDisabled, allowMultiple: true, onUpload });
return (
<>
<IconButton
aria-label="Upload image"
variant="outline"
sx={sx}
data-error={isError}
icon={<PiUploadBold />}

View File

@@ -22,12 +22,15 @@ const sx = {
},
} satisfies SystemStyleObject;
type Props = ImageProps & {
imageDTO: ImageDTO;
asThumbnail?: boolean;
};
/* eslint-disable-next-line @typescript-eslint/no-namespace */
export namespace DndImage {
export interface Props extends ImageProps {
imageDTO: ImageDTO;
asThumbnail?: boolean;
}
}
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: DndImage.Props) => {
const store = useAppStore();
const [isDragging, setIsDragging] = useState(false);
const [element, ref] = useState<HTMLImageElement | null>(null);

View File

@@ -0,0 +1,26 @@
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
import { DndImage } from 'features/dnd/DndImage';
import { memo } from 'react';
import { PiExclamationMarkBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
/* eslint-disable-next-line @typescript-eslint/no-namespace */
namespace DndImageFromImageName {
export interface Props extends Omit<DndImage.Props, 'imageDTO'> {
imageName: string;
}
}
export const DndImageFromImageName = memo(({ imageName, ...rest }: DndImageFromImageName.Props) => {
const query = useGetImageDTOQuery(imageName);
if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;
}
if (!query.data) {
return <IAINoContentFallback icon={<PiExclamationMarkBold />} />;
}
return <DndImage imageDTO={query.data} {...rest} />;
});
DndImageFromImageName.displayName = 'DndImageFromImageName';

View File

@@ -9,6 +9,7 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
import type { BoardId } from 'features/gallery/store/types';
import {
addImagesToBoard,
addImagesToNodeImageFieldCollectionAction,
createNewCanvasEntityFromImage,
removeImagesFromBoard,
replaceCanvasEntityObjectsWithImage,
@@ -241,6 +242,45 @@ export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDn
};
//#endregion
//#region Add Images to Image Collection Node Field
const _addImagesToNodeImageFieldCollection = buildTypeAndKey('add-images-to-image-collection-node-field');
export type AddImagesToNodeImageFieldCollection = DndData<
typeof _addImagesToNodeImageFieldCollection.type,
typeof _addImagesToNodeImageFieldCollection.key,
{ fieldIdentifer: FieldIdentifier }
>;
export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
AddImagesToNodeImageFieldCollection,
SingleImageDndSourceData | MultipleImageDndSourceData
> = {
..._addImagesToNodeImageFieldCollection,
typeGuard: buildTypeGuard(_addImagesToNodeImageFieldCollection.key),
getData: buildGetData(_addImagesToNodeImageFieldCollection.key, _addImagesToNodeImageFieldCollection.type),
isValid: ({ sourceData }) => {
if (singleImageDndSource.typeGuard(sourceData) || multipleImageDndSource.typeGuard(sourceData)) {
return true;
}
return false;
},
handler: ({ sourceData, targetData, dispatch }) => {
if (!singleImageDndSource.typeGuard(sourceData) && !multipleImageDndSource.typeGuard(sourceData)) {
return;
}
const { fieldIdentifer } = targetData.payload;
const imageDTOs: ImageDTO[] = [];
if (singleImageDndSource.typeGuard(sourceData)) {
imageDTOs.push(sourceData.payload.imageDTO);
} else {
imageDTOs.push(...sourceData.payload.imageDTOs);
}
addImagesToNodeImageFieldCollectionAction({ fieldIdentifer, imageDTOs, dispatch });
},
};
//#endregion
//# Set Comparison Image
const _setComparisonImage = buildTypeAndKey('set-comparison-image');
export type SetComparisonImageDndTargetData = DndData<
@@ -430,6 +470,7 @@ export const dndTargets = [
// Single or Multiple Image
addImageToBoardDndTarget,
removeImageFromBoardDndTarget,
addImagesToNodeImageFieldCollectionDndTarget,
] as const;
export type AnyDndTarget = (typeof dndTargets)[number];

View File

@@ -29,7 +29,7 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
@@ -71,6 +71,15 @@ export const setNodeImageFieldImage = (arg: {
dispatch(fieldImageValueChanged({ ...fieldIdentifer, value: imageDTO }));
};
export const addImagesToNodeImageFieldCollectionAction = (arg: {
imageDTOs: ImageDTO[];
fieldIdentifer: FieldIdentifier;
dispatch: AppDispatch;
}) => {
const { imageDTOs, fieldIdentifer, dispatch } = arg;
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifer, value: imageDTOs }));
};
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(imageToCompareChanged(imageDTO));

View File

@@ -40,7 +40,7 @@ import { computed } from 'nanostores';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCircuitryBold, PiFlaskBold, PiHammerBold } from 'react-icons/pi';
import { PiCircuitryBold, PiFlaskBold, PiHammerBold, PiLightningFill } from 'react-icons/pi';
import type { EdgeChange, NodeChange } from 'reactflow';
import type { S } from 'services/api/types';
@@ -403,7 +403,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
}
return _items;
}, [pendingConnection, currentImageFilterItem, searchTerm, notesFilterItem, templatesArray]);
}, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem]);
return (
<>
@@ -414,6 +414,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
{item.classification === 'beta' && <Icon boxSize={4} color="invokeYellow.300" as={PiHammerBold} />}
{item.classification === 'prototype' && <Icon boxSize={4} color="invokeRed.300" as={PiFlaskBold} />}
{item.classification === 'internal' && <Icon boxSize={4} color="invokePurple.300" as={PiCircuitryBold} />}
{item.classification === 'special' && <Icon boxSize={4} color="invokeGreen.300" as={PiLightningFill} />}
<Text fontWeight="semibold">{item.label}</Text>
<Spacer />
<Text variant="subtext" fontWeight="semibold">

View File

@@ -3,7 +3,7 @@ import { useNodeClassification } from 'features/nodes/hooks/useNodeClassificatio
import type { Classification } from 'features/nodes/types/common';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiCircuitryBold, PiFlaskBold, PiHammerBold } from 'react-icons/pi';
import { PiCircuitryBold, PiFlaskBold, PiHammerBold, PiLightningFill } from 'react-icons/pi';
interface Props {
nodeId: string;
@@ -62,5 +62,9 @@ const ClassificationIcon = ({ classification }: { classification: Classification
return <Icon as={PiCircuitryBold} display="block" boxSize={4} color="invokePurple.300" />;
}
if (classification === 'special') {
return <Icon as={PiLightningFill} display="block" boxSize={4} color="invokeGreen.300" />;
}
return null;
};

View File

@@ -1,9 +1,9 @@
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
import { compare } from 'compare-versions';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNode } from 'features/nodes/hooks/useNode';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiInfoBold } from 'react-icons/pi';
@@ -25,32 +25,32 @@ const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
export default memo(InvocationNodeInfoIcon);
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const node = useNode(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
const { t } = useTranslation();
const title = useMemo(() => {
if (data?.label && nodeTemplate?.title) {
return `${data.label} (${nodeTemplate.title})`;
if (node.data?.label && nodeTemplate?.title) {
return `${node.data.label} (${nodeTemplate.title})`;
}
if (data?.label && !nodeTemplate) {
return data.label;
if (node.data?.label && !nodeTemplate) {
return node.data.label;
}
if (!data?.label && nodeTemplate) {
if (!node.data?.label && nodeTemplate) {
return nodeTemplate.title;
}
return t('nodes.unknownNode');
}, [data, nodeTemplate, t]);
}, [node.data.label, nodeTemplate, t]);
const versionComponent = useMemo(() => {
if (!isInvocationNodeData(data) || !nodeTemplate) {
if (!isInvocationNode(node) || !nodeTemplate) {
return null;
}
if (!data.version) {
if (!node.data.version) {
return (
<Text as="span" color="error.500">
{t('nodes.versionUnknown')}
@@ -61,35 +61,35 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
if (!nodeTemplate.version) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.unknownTemplate')})
{t('nodes.version')} {node.data.version} ({t('nodes.unknownTemplate')})
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '<')) {
if (compare(node.data.version, nodeTemplate.version, '<')) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.updateNode')})
{t('nodes.version')} {node.data.version} ({t('nodes.updateNode')})
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '>')) {
if (compare(node.data.version, nodeTemplate.version, '>')) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.updateApp')})
{t('nodes.version')} {node.data.version} ({t('nodes.updateApp')})
</Text>
);
}
return (
<Text as="span">
{t('nodes.version')} {data.version}
{t('nodes.version')} {node.data.version}
</Text>
);
}, [data, nodeTemplate, t]);
}, [node, nodeTemplate, t]);
if (!isInvocationNodeData(data)) {
if (!isInvocationNode(node)) {
return <Text fontWeight="semibold">{t('nodes.unknownNode')}</Text>;
}
@@ -107,7 +107,7 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
{nodeTemplate?.description}
</Text>
{versionComponent}
{data?.notes && <Text>{data.notes}</Text>}
{node.data?.notes && <Text>{node.data.notes}</Text>}
</Flex>
);
});

View File

@@ -1,15 +1,15 @@
import { FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNode } from 'features/nodes/hooks/useNode';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const node = useNode(nodeId);
const { t } = useTranslation();
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
@@ -17,13 +17,13 @@ const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
if (!isInvocationNode(node)) {
return null;
}
return (
<FormControl orientation="vertical" h="full">
<FormLabel>{t('nodes.notes')}</FormLabel>
<Textarea value={data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
<Textarea value={node.data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
</FormControl>
);
};

View File

@@ -1,3 +1,4 @@
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
@@ -24,6 +25,8 @@ import {
isFluxMainModelFieldInputTemplate,
isFluxVAEModelFieldInputInstance,
isFluxVAEModelFieldInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldInputInstance,
@@ -110,6 +113,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
return <ImageFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -0,0 +1,99 @@
import { Flex, Grid, GridItem, IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { UploadMultipleImageButton } from 'common/hooks/useImageUploadButton';
import type { AddImagesToNodeImageFieldCollection } from 'features/dnd/dnd';
import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImageFromImageName } from 'features/dnd/DndImageFromImageName';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import type { ImageDTO } from 'services/api/types';
import type { FieldComponentProps } from './types';
export const ImageFieldCollectionInputComponent = memo(
(props: FieldComponentProps<ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate>) => {
const { t } = useTranslation();
const { nodeId, field, fieldTemplate } = props;
const dispatch = useAppDispatch();
const onReset = useCallback(() => {
dispatch(
fieldImageCollectionValueChanged({
nodeId,
fieldName: field.name,
value: [],
})
);
}, [dispatch, field.name, nodeId]);
const dndTargetData = useMemo<AddImagesToNodeImageFieldCollection>(
() => addImagesToNodeImageFieldCollectionDndTarget.getData({ fieldIdentifer: { nodeId, fieldName: field.name } }),
[field, nodeId]
);
const onUpload = useCallback(
(imageDTOs: ImageDTO[]) => {
dispatch(
fieldImageCollectionValueChanged({
nodeId,
fieldName: field.name,
value: imageDTOs,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<Flex
position="relative"
className="nodrag"
w="full"
h="full"
minH={16}
alignItems="stretch"
justifyContent="center"
>
{field.value.length === 0 && (
<UploadMultipleImageButton
w="full"
h="auto"
isError={fieldTemplate.required}
onUpload={onUpload}
fontSize={24}
variant="outline"
/>
)}
{field.value.length > 0 && (
<>
<Grid className="nopan" w="full" h="full" templateColumns="repeat(3, 1fr)" gap={2}>
{field.value.map(({ image_name }) => (
<GridItem key={image_name}>
<DndImageFromImageName imageName={image_name} asThumbnail />
</GridItem>
))}
</Grid>
<IconButton
aria-label="reset"
icon={<PiArrowCounterClockwiseBold />}
position="absolute"
top={0}
insetInlineEnd={0}
onClick={onReset}
/>
</>
)}
<DndDropTarget
dndTarget={addImagesToNodeImageFieldCollectionDndTarget}
dndTargetData={dndTargetData}
label={t('gallery.drop')}
/>
</Flex>
);
}
);
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';

View File

@@ -0,0 +1,19 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNode, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
import type { Node } from 'reactflow';
export const useNode = (nodeId: string): Node => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectNode(nodes, nodeId);
}),
[nodeId]
);
const node = useAppSelector(selector);
return node;
};

View File

@@ -16,6 +16,7 @@ import type {
FieldValue,
FloatFieldValue,
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
ImageFieldValue,
IntegerFieldValue,
IPAdapterModelFieldValue,
@@ -42,6 +43,7 @@ import {
zEnumFieldValue,
zFloatFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
zImageFieldValue,
zIntegerFieldValue,
zIPAdapterModelFieldValue,
@@ -319,6 +321,9 @@ export const nodesSlice = createSlice({
fieldImageValueChanged: (state, action: FieldValueAction<ImageFieldValue>) => {
fieldValueReducer(state, action, zImageFieldValue);
},
fieldImageCollectionValueChanged: (state, action: FieldValueAction<ImageFieldCollectionValue>) => {
fieldValueReducer(state, action, zImageFieldCollectionValue);
},
fieldColorValueChanged: (state, action: FieldValueAction<ColorFieldValue>) => {
fieldValueReducer(state, action, zColorFieldValue);
},
@@ -416,6 +421,7 @@ export const {
fieldControlNetModelValueChanged,
fieldEnumModelValueChanged,
fieldImageValueChanged,
fieldImageCollectionValueChanged,
fieldIPAdapterModelValueChanged,
fieldT2IAdapterModelValueChanged,
fieldSpandrelImageToImageModelValueChanged,
@@ -527,6 +533,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldControlNetModelValueChanged,
fieldEnumModelValueChanged,
fieldImageValueChanged,
fieldImageCollectionValueChanged,
fieldIPAdapterModelValueChanged,
fieldT2IAdapterModelValueChanged,
fieldLabelChanged,

View File

@@ -5,8 +5,15 @@ import type { NodesState } from 'features/nodes/store/types';
import type { FieldInputInstance } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { Node } from 'reactflow';
import { assert } from 'tsafe';
export const selectNode = (nodesSlice: NodesState, nodeId: string): Node => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
assert(node !== undefined, `Node ${nodeId} not found`);
return node;
};
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);

View File

@@ -22,7 +22,7 @@ export const zColorField = z.object({
});
export type ColorField = z.infer<typeof zColorField>;
export const zClassification = z.enum(['stable', 'beta', 'prototype', 'deprecated', 'internal']);
export const zClassification = z.enum(['stable', 'beta', 'prototype', 'deprecated', 'internal', 'special']);
export type Classification = z.infer<typeof zClassification>;
export const zSchedulerField = z.enum([

View File

@@ -59,6 +59,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
EnumField: 'blue.500',
FloatField: 'orange.500',
ImageField: 'purple.500',
ImageBatchField: 'purple.500',
IntegerField: 'red.500',
IPAdapterField: 'teal.500',
IPAdapterModelField: 'teal.500',

View File

@@ -95,6 +95,11 @@ const zImageFieldType = zFieldTypeBase.extend({
name: z.literal('ImageField'),
originalType: zStatelessFieldType.optional(),
});
const zImageCollectionFieldType = z.object({
name: z.literal('ImageField'),
cardinality: z.literal(zCardinality.Values.COLLECTION),
originalType: zStatelessFieldType.optional(),
});
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
originalType: zStatelessFieldType.optional(),
@@ -347,7 +352,6 @@ export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTem
// #endregion
// #region ImageField
export const zImageFieldValue = zImageField.optional();
const zImageFieldInputInstance = zFieldInputInstanceBase.extend({
value: zImageFieldValue,
@@ -369,6 +373,28 @@ export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputT
zImageFieldInputTemplate.safeParse(val).success;
// #endregion
// #region ImageField Collection
export const zImageFieldCollectionValue = z.array(zImageField);
const zImageFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zImageFieldCollectionValue,
});
const zImageFieldCollectionInputTemplate = zFieldInputTemplateBase.extend({
type: zImageCollectionFieldType,
originalType: zFieldType.optional(),
default: zImageFieldCollectionValue,
});
const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zImageCollectionFieldType,
});
export type ImageFieldCollectionValue = z.infer<typeof zImageFieldCollectionValue>;
export type ImageFieldCollectionInputInstance = z.infer<typeof zImageFieldCollectionInputInstance>;
export type ImageFieldCollectionInputTemplate = z.infer<typeof zImageFieldCollectionInputTemplate>;
export const isImageFieldCollectionInputInstance = (val: unknown): val is ImageFieldCollectionInputInstance =>
zImageFieldCollectionInputInstance.safeParse(val).success;
export const isImageFieldCollectionInputTemplate = (val: unknown): val is ImageFieldCollectionInputTemplate =>
zImageFieldCollectionInputTemplate.safeParse(val).success;
// #endregion
// #region BoardField
export const zBoardFieldValue = zBoardField.optional();
@@ -885,6 +911,7 @@ export const zStatefulFieldValue = z.union([
zBooleanFieldValue,
zEnumFieldValue,
zImageFieldValue,
zImageFieldCollectionValue,
zBoardFieldValue,
zModelIdentifierFieldValue,
zMainModelFieldValue,
@@ -920,6 +947,7 @@ const zStatefulFieldInputInstance = z.union([
zBooleanFieldInputInstance,
zEnumFieldInputInstance,
zImageFieldInputInstance,
zImageFieldCollectionInputInstance,
zBoardFieldInputInstance,
zModelIdentifierFieldInputInstance,
zMainModelFieldInputInstance,
@@ -954,6 +982,7 @@ const zStatefulFieldInputTemplate = z.union([
zBooleanFieldInputTemplate,
zEnumFieldInputTemplate,
zImageFieldInputTemplate,
zImageFieldCollectionInputTemplate,
zBoardFieldInputTemplate,
zModelIdentifierFieldInputTemplate,
zMainModelFieldInputTemplate,
@@ -991,6 +1020,7 @@ const zStatefulFieldOutputTemplate = z.union([
zBooleanFieldOutputTemplate,
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,
zImageFieldCollectionOutputTemplate,
zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,

View File

@@ -64,8 +64,6 @@ export type AnyNode = Node<AnyNodeData>;
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
Boolean(node && node.type === 'invocation');
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData =>
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
// #endregion
// #region NodeExecutionState

View File

@@ -1,16 +1,20 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
const log = logger('workflows');
/**
* Builds a graph from the node editor state.
*/
export const buildNodesGraph = (nodesState: NodesState): Graph => {
const { nodes, edges } = nodesState;
const filteredNodes = nodes.filter(isInvocationNode);
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
@@ -47,22 +51,31 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
return nodesAccumulator;
}, {});
const filteredNodeIds = filteredNodes.map(({ id }) => id);
// skip out the "dummy" edges between collapsed nodes
const filteredEdges = edges.filter((n) => n.type !== 'collapsed');
const filteredEdges = edges
.filter((edge) => edge.type !== 'collapsed')
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
// Reduce the node editor edges into invocation graph edges
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
const { source, target, sourceHandle, targetHandle } = edge;
if (!sourceHandle || !targetHandle) {
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
return edgesAccumulator;
}
// Format the edges and add to the edges array
edgesAccumulator.push({
source: {
node_id: source,
field: sourceHandle as string,
field: sourceHandle,
},
destination: {
node_id: target,
field: targetHandle as string,
field: targetHandle,
},
});

View File

@@ -101,7 +101,7 @@ const QueueItemComponent = ({ index, item, context }: InnerItemProps) => {
<Text as="span" fontWeight="semibold">
{node_path}.{field_name}
</Text>
: {value}
: {JSON.stringify(value)}
</Text>
))}
</Flex>