Compare commits

...

1 Commits

Author SHA1 Message Date
psychedelicious
c3304dcbe3 feat(ui): add image batching to workflows (wip) 2024-11-13 14:23:04 -08:00
25 changed files with 444 additions and 66 deletions

View File

@@ -16,6 +16,7 @@ from pydantic import (
from pydantic_core import to_jsonable_python
from invokeai.app.invocations.baseinvocation import BaseInvocation
from invokeai.app.invocations.fields import ImageField
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
from invokeai.app.services.workflow_records.workflow_records_common import (
WorkflowWithoutID,
@@ -51,11 +52,7 @@ class SessionQueueItemNotFoundError(ValueError):
# region Batch
BatchDataType = Union[
StrictStr,
float,
int,
]
BatchDataType = Union[StrictStr, float, int, ImageField]
class NodeFieldValue(BaseModel):

View File

@@ -1,10 +1,11 @@
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageBatchNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { queueApi } from 'services/api/endpoints/queue';
import type { BatchConfig } from 'services/api/types';
import type { Batch, BatchConfig } from 'services/api/types';
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
startAppListening({
@@ -26,6 +27,18 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
delete builtWorkflow.id;
}
const data: Batch['data'] = [];
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isImageBatchNode);
for (const imageBatch of imageBatchNodes) {
const edge = nodes.edges.find((e) => e.source === imageBatch.id);
if (!edge || !edge.targetHandle) {
break;
}
data.push([{ node_path: edge.target, field_name: edge.targetHandle, items: imageBatch.data.images }]);
}
const batchConfig: BatchConfig = {
batch: {
graph,
@@ -33,6 +46,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend: action.payload.prepend,
};

View File

@@ -20,12 +20,15 @@ const sx = {
},
} satisfies SystemStyleObject;
type Props = ImageProps & {
imageDTO: ImageDTO;
asThumbnail?: boolean;
};
/* eslint-disable-next-line @typescript-eslint/no-namespace */
export namespace DndImage {
export interface Props extends ImageProps {
imageDTO: ImageDTO;
asThumbnail?: boolean;
}
}
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: DndImage.Props) => {
const store = useAppStore();
const [isDragging, setIsDragging] = useState(false);
const [element, ref] = useState<HTMLImageElement | null>(null);

View File

@@ -0,0 +1,26 @@
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
import { DndImage } from 'features/dnd/DndImage';
import { memo } from 'react';
import { PiExclamationMarkBold } from 'react-icons/pi';
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
/* eslint-disable-next-line @typescript-eslint/no-namespace */
namespace DndImageFromImageName {
export interface Props extends Omit<DndImage.Props, 'imageDTO'> {
imageName: string;
}
}
export const DndImageFromImageName = memo(({ imageName, ...rest }: DndImageFromImageName.Props) => {
const query = useGetImageDTOQuery(imageName);
if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;
}
if (!query.data) {
return <IAINoContentFallback icon={<PiExclamationMarkBold />} />;
}
return <DndImage imageDTO={query.data} {...rest} />;
});
DndImageFromImageName.displayName = 'DndImageFromImageName';

View File

@@ -18,6 +18,7 @@ import {
setRegionalGuidanceReferenceImage,
setUpscaleInitialImage,
} from 'features/imageActions/actions';
import { batchImageInputNodeImagesAdded } from 'features/nodes/store/nodesSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
@@ -241,6 +242,47 @@ export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDn
};
//#endregion
//#region Add Images to Batch Image Input Node
const _addImagesToBatchImageInputNode = buildTypeAndKey('add-images-to-batch-image-input-node');
export type AddImagesToBatchImageInputNodeDndTargetData = DndData<
typeof _addImagesToBatchImageInputNode.type,
typeof _addImagesToBatchImageInputNode.key,
{ nodeId: string }
>;
export const addImagesToBatchImageInputNodeDndTarget: DndTarget<
AddImagesToBatchImageInputNodeDndTargetData,
SingleImageDndSourceData | MultipleImageDndSourceData
> = {
..._addImagesToBatchImageInputNode,
typeGuard: buildTypeGuard(_addImagesToBatchImageInputNode.key),
getData: buildGetData(_addImagesToBatchImageInputNode.key, _addImagesToBatchImageInputNode.type),
isValid: ({ sourceData }) => {
if (singleImageDndSource.typeGuard(sourceData) || multipleImageDndSource.typeGuard(sourceData)) {
return true;
}
return false;
},
handler: ({ sourceData, targetData, dispatch }) => {
if (!singleImageDndSource.typeGuard(sourceData) && !multipleImageDndSource.typeGuard(sourceData)) {
return;
}
const { nodeId } = targetData.payload;
const imageDTOs: ImageDTO[] = [];
if (singleImageDndSource.typeGuard(sourceData)) {
imageDTOs.push(sourceData.payload.imageDTO);
} else {
imageDTOs.push(...sourceData.payload.imageDTOs);
}
const images = imageDTOs.map(({ image_name }) => ({ image_name }));
dispatch(batchImageInputNodeImagesAdded({ nodeId, images }));
},
};
//#endregion
//# Set Comparison Image
const _setComparisonImage = buildTypeAndKey('set-comparison-image');
export type SetComparisonImageDndTargetData = DndData<
@@ -430,6 +472,7 @@ export const dndTargets = [
// Single or Multiple Image
addImageToBoardDndTarget,
removeImageFromBoardDndTarget,
addImagesToBatchImageInputNodeDndTarget,
] as const;
export type AnyDndTarget = (typeof dndTargets)[number];

View File

@@ -347,6 +347,18 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
}),
[t]
);
const batchInputImageItem = useMemo<FilterableItem>(
() =>
({
type: 'batch_input',
title: 'batch_input',
description: 'batch_input',
tags: ['batch'],
classification: 'stable',
nodePack: 'invokeai',
}) as const,
[]
);
const items = useMemo<NodeCommandItemData[]>(() => {
// If we have a connection in progress, we need to filter the node choices
@@ -365,7 +377,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
}
}
for (const item of [currentImageFilterItem, notesFilterItem]) {
for (const item of [currentImageFilterItem, notesFilterItem, batchInputImageItem]) {
if (filter(item, searchTerm)) {
_items.push({
label: item.title,
@@ -403,7 +415,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
}
return _items;
}, [pendingConnection, currentImageFilterItem, searchTerm, notesFilterItem, templatesArray]);
}, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem, batchInputImageItem]);
return (
<>

View File

@@ -2,6 +2,7 @@ import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
import { BatchImageInputNode } from 'features/nodes/components/flow/nodes/BatchInput/BatchImageInputNode';
import { useConnection } from 'features/nodes/hooks/useConnection';
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
@@ -64,6 +65,7 @@ const nodeTypes = {
invocation: InvocationNodeWrapper,
current_image: CurrentImageNode,
notes: NotesNode,
image_batch: BatchImageInputNode,
};
// TODO: can we support reactflow? if not, we could style the attribution so it matches the app

View File

@@ -0,0 +1,119 @@
import { Box, Flex, Grid, GridItem, IconButton } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import type { AddImagesToBatchImageInputNodeDndTargetData } from 'features/dnd/dnd';
import { addImagesToBatchImageInputNodeDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImageFromImageName } from 'features/dnd/DndImageFromImageName';
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
import FieldHandle from 'features/nodes/components/flow/nodes/Invocation/fields/FieldHandle';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { batchImageInputNodeReset } from 'features/nodes/store/nodesSlice';
import { imageBatchOutputFieldTemplate } from 'features/nodes/types/field';
import type { ImageBatchNodeData } from 'features/nodes/types/invocation';
import { memo, useCallback, useMemo } from 'react';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
import type { NodeProps } from 'reactflow';
export const BatchImageInputNode = memo((props: NodeProps<ImageBatchNodeData>) => {
const { id: nodeId, data, selected } = props;
const { images, isOpen } = data;
const dispatch = useAppDispatch();
const onReset = useCallback(() => {
dispatch(batchImageInputNodeReset({ nodeId }));
}, [dispatch, nodeId]);
const targetData = useMemo<AddImagesToBatchImageInputNodeDndTargetData>(
() => addImagesToBatchImageInputNodeDndTarget.getData({ nodeId }),
[nodeId]
);
return (
<NodeWrapper nodeId={nodeId} selected={selected}>
<Flex
layerStyle="nodeHeader"
borderTopRadius="base"
borderBottomRadius={isOpen ? 0 : 'base'}
alignItems="center"
justifyContent="space-between"
h={8}
>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<NodeTitle nodeId={nodeId} title="Batch Image Input" />
<Box minW={8} />
</Flex>
{isOpen && (
<>
<Flex
position="relative"
layerStyle="nodeBody"
className="nopan"
cursor="auto"
flexDirection="column"
borderBottomRadius="base"
w="full"
h="full"
p={2}
gap={1}
minH={16}
>
<Grid className="nopan" w="full" h="full" templateColumns="repeat(3, 1fr)" gap={2}>
{images.map(({ image_name }) => (
<GridItem key={image_name}>
<DndImageFromImageName imageName={image_name} asThumbnail />
</GridItem>
))}
</Grid>
<IconButton
aria-label="reset"
icon={<PiArrowCounterClockwiseBold />}
position="absolute"
top={0}
insetInlineEnd={0}
onClick={onReset}
variant="ghost"
/>
</Flex>
</>
)}
<ImageBatchOutputField nodeId={nodeId} />
<DndDropTarget
dndTarget={addImagesToBatchImageInputNodeDndTarget}
dndTargetData={targetData}
label="Add to Batch"
/>
</NodeWrapper>
);
});
BatchImageInputNode.displayName = 'BatchImageInputNode';
const ImageBatchOutputField = memo(({ nodeId }: { nodeId: string }) => {
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName: 'images', kind: 'outputs' });
return (
<Flex
position="absolute"
minH={8}
top="50%"
translateY="-50%"
insetInlineEnd={2}
alignItems="center"
opacity={shouldDim ? 0.5 : 1}
transitionProperty="opacity"
transitionDuration="0.1s"
justifyContent="flex-end"
>
<FieldHandle
fieldTemplate={imageBatchOutputFieldTemplate}
handleType="source"
isConnectionInProgress={isConnectionInProgress}
isConnectionStartField={isConnectionStartField}
validationResult={validationResult}
/>
</Flex>
);
});
ImageBatchOutputField.displayName = 'ImageBatchOutputField';

View File

@@ -1,9 +1,9 @@
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
import { compare } from 'compare-versions';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNode } from 'features/nodes/hooks/useNode';
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiInfoBold } from 'react-icons/pi';
@@ -25,32 +25,32 @@ const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
export default memo(InvocationNodeInfoIcon);
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
const data = useNodeData(nodeId);
const node = useNode(nodeId);
const nodeTemplate = useNodeTemplate(nodeId);
const { t } = useTranslation();
const title = useMemo(() => {
if (data?.label && nodeTemplate?.title) {
return `${data.label} (${nodeTemplate.title})`;
if (node.data?.label && nodeTemplate?.title) {
return `${node.data.label} (${nodeTemplate.title})`;
}
if (data?.label && !nodeTemplate) {
return data.label;
if (node.data?.label && !nodeTemplate) {
return node.data.label;
}
if (!data?.label && nodeTemplate) {
if (!node.data?.label && nodeTemplate) {
return nodeTemplate.title;
}
return t('nodes.unknownNode');
}, [data, nodeTemplate, t]);
}, [node.data.label, nodeTemplate, t]);
const versionComponent = useMemo(() => {
if (!isInvocationNodeData(data) || !nodeTemplate) {
if (!isInvocationNode(node) || !nodeTemplate) {
return null;
}
if (!data.version) {
if (!node.data.version) {
return (
<Text as="span" color="error.500">
{t('nodes.versionUnknown')}
@@ -61,35 +61,35 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
if (!nodeTemplate.version) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.unknownTemplate')})
{t('nodes.version')} {node.data.version} ({t('nodes.unknownTemplate')})
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '<')) {
if (compare(node.data.version, nodeTemplate.version, '<')) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.updateNode')})
{t('nodes.version')} {node.data.version} ({t('nodes.updateNode')})
</Text>
);
}
if (compare(data.version, nodeTemplate.version, '>')) {
if (compare(node.data.version, nodeTemplate.version, '>')) {
return (
<Text as="span" color="error.500">
{t('nodes.version')} {data.version} ({t('nodes.updateApp')})
{t('nodes.version')} {node.data.version} ({t('nodes.updateApp')})
</Text>
);
}
return (
<Text as="span">
{t('nodes.version')} {data.version}
{t('nodes.version')} {node.data.version}
</Text>
);
}, [data, nodeTemplate, t]);
}, [node, nodeTemplate, t]);
if (!isInvocationNodeData(data)) {
if (!isInvocationNode(node)) {
return <Text fontWeight="semibold">{t('nodes.unknownNode')}</Text>;
}
@@ -107,7 +107,7 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
{nodeTemplate?.description}
</Text>
{versionComponent}
{data?.notes && <Text>{data.notes}</Text>}
{node.data?.notes && <Text>{node.data.notes}</Text>}
</Flex>
);
});

View File

@@ -1,15 +1,15 @@
import { FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useNodeData } from 'features/nodes/hooks/useNodeData';
import { useNode } from 'features/nodes/hooks/useNode';
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
import { isInvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
const dispatch = useAppDispatch();
const data = useNodeData(nodeId);
const node = useNode(nodeId);
const { t } = useTranslation();
const handleNotesChanged = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
@@ -17,13 +17,13 @@ const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
},
[dispatch, nodeId]
);
if (!isInvocationNodeData(data)) {
if (!isInvocationNode(node)) {
return null;
}
return (
<FormControl orientation="vertical" h="full">
<FormLabel>{t('nodes.notes')}</FormLabel>
<Textarea value={data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
<Textarea value={node.data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
</FormControl>
);
};

View File

@@ -23,6 +23,8 @@ const InputField = ({ nodeId, fieldName }: Props) => {
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
console.log({ isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim });
const isMissingInput = useMemo(() => {
if (!fieldTemplate) {
return false;

View File

@@ -64,7 +64,7 @@ type OutputFieldWrapperProps = PropsWithChildren<{
shouldDim: boolean;
}>;
const OutputFieldWrapper = memo(({ shouldDim, children }: OutputFieldWrapperProps) => (
export const OutputFieldWrapper = memo(({ shouldDim, children }: OutputFieldWrapperProps) => (
<Flex
position="relative"
minH={8}

View File

@@ -2,6 +2,7 @@ import { useStore } from '@nanostores/react';
import { $templates } from 'features/nodes/store/nodesSlice';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { buildBatchInputImageNode } from 'features/nodes/util/node/buildBatchInputImageNode';
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode';
@@ -14,7 +15,7 @@ export const useBuildNode = () => {
return useCallback(
// string here is "any invocation type"
(type: string | 'current_image' | 'notes'): AnyNode => {
(type: string | 'current_image' | 'notes' | 'batch_input'): AnyNode => {
let _x = window.innerWidth / 2;
let _y = window.innerHeight / 2;
@@ -39,6 +40,10 @@ export const useBuildNode = () => {
return buildNotesNode(position);
}
if (type === 'batch_input') {
return buildBatchInputImageNode(position);
}
// TODO: Keep track of invocation types so we do not need to cast this
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
const template = templates[type] as InvocationTemplate;

View File

@@ -12,6 +12,11 @@ import {
import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import {
type FieldInputTemplate,
type FieldOutputTemplate,
imageBatchOutputFieldTemplate,
} from 'features/nodes/types/field';
import { useCallback, useMemo } from 'react';
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
import { useUpdateNodeInternals } from 'reactflow';
@@ -33,13 +38,20 @@ export const useConnection = () => {
return;
}
const template = templates[node.data.type];
if (!template) {
return;
let fieldTemplate: FieldInputTemplate | FieldOutputTemplate | undefined = undefined;
if (node.type === 'image_batch' && handleId === 'images' && handleType === 'source') {
fieldTemplate = imageBatchOutputFieldTemplate;
} else {
const template = templates[node.data.type];
if (!template) {
return;
}
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
fieldTemplate = fieldTemplates[handleId];
}
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
const fieldTemplate = fieldTemplates[handleId];
if (!fieldTemplate) {
return;
}

View File

@@ -0,0 +1,19 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { selectNode, selectNodesSlice } from 'features/nodes/store/selectors';
import { useMemo } from 'react';
import type { Node } from 'reactflow';
export const useNode = (nodeId: string): Node => {
const selector = useMemo(
() =>
createMemoizedSelector(selectNodesSlice, (nodes) => {
return selectNode(nodes, nodeId);
}),
[nodeId]
);
const node = useAppSelector(selector);
return node;
};

View File

@@ -3,6 +3,7 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit';
import type { PersistConfig } from 'app/store/store';
import { buildUseBoolean } from 'common/hooks/useBoolean';
import { workflowLoaded } from 'features/nodes/store/actions';
import type { ImageField } from 'features/nodes/types/common';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
@@ -58,7 +59,8 @@ import {
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { isImageBatchNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { uniqBy } from 'lodash-es';
import { atom, computed } from 'nanostores';
import type { MouseEvent } from 'react';
import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow';
@@ -193,7 +195,7 @@ export const nodesSlice = createSlice({
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node) && !isNotesNode(node)) {
if (!isInvocationNode(node) && !isNotesNode(node) && !isImageBatchNode(node)) {
return;
}
@@ -382,6 +384,34 @@ export const nodesSlice = createSlice({
}
node.data.notes = value;
},
batchImageInputNodeImagesAdded: (state, action: PayloadAction<{ nodeId: string; images: ImageField[] }>) => {
const { nodeId, images } = action.payload;
const node = state.nodes.find((n) => n.id === nodeId);
if (!isImageBatchNode(node)) {
return;
}
node.data.images = uniqBy([...node.data.images, ...images], 'image_name');
},
batchImageInputNodeImagesRemoved: (state, action: PayloadAction<{ nodeId: string; images: ImageField[] }>) => {
const { nodeId, images } = action.payload;
const node = state.nodes.find((n) => n.id === nodeId);
if (!isImageBatchNode(node)) {
return;
}
const imageNamesToRemove = images.map(({ image_name }) => image_name);
node.data.images = uniqBy(
node.data.images.filter(({ image_name }) => imageNamesToRemove.includes(image_name)),
'image_name'
);
},
batchImageInputNodeReset: (state, action: PayloadAction<{ nodeId: string }>) => {
const { nodeId } = action.payload;
const node = state.nodes.find((n) => n.id === nodeId);
if (!isImageBatchNode(node)) {
return;
}
node.data.images = [];
},
nodeEditorReset: (state) => {
state.nodes = [];
state.edges = [];
@@ -443,6 +473,9 @@ export const {
notesNodeValueChanged,
undo,
redo,
batchImageInputNodeImagesAdded,
batchImageInputNodeImagesRemoved,
batchImageInputNodeReset,
} = nodesSlice.actions;
export const $cursorPos = atom<XYPosition | null>(null);

View File

@@ -5,8 +5,15 @@ import type { NodesState } from 'features/nodes/store/types';
import type { FieldInputInstance } from 'features/nodes/types/field';
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { Node } from 'reactflow';
import { assert } from 'tsafe';
export const selectNode = (nodesSlice: NodesState, nodeId: string): Node => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
assert(node !== undefined, `Node ${nodeId} not found`);
return node;
};
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);

View File

@@ -3,6 +3,7 @@ import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import { imageBatchOutputFieldTemplate } from 'features/nodes/types/field';
import type { AnyNode } from 'features/nodes/types/invocation';
import type { Connection as NullableConnection, Edge } from 'reactflow';
import type { SetNonNullable } from 'type-fest';
@@ -78,23 +79,32 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
return buildRejectResult('nodes.missingNode');
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
if (sourceNode.type === 'image_batch') {
const isValid = validateConnectionTypes(imageBatchOutputFieldTemplate.type, targetFieldTemplate.type);
if (!isValid) {
return buildRejectResult('nodes.fieldTypesMustMatch');
} else {
return buildAcceptResult();
}
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
}

View File

@@ -15,6 +15,10 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
return false;
}
if (sourceType.name === 'ImageBatchField') {
return isSingle(sourceType) && targetType.name === 'ImageField';
}
if (areTypesEqual(sourceType, targetType)) {
return true;
}

View File

@@ -59,6 +59,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
EnumField: 'blue.500',
FloatField: 'orange.500',
ImageField: 'purple.500',
ImageBatchField: 'purple.500',
IntegerField: 'red.500',
IPAdapterField: 'teal.500',
IPAdapterModelField: 'teal.500',

View File

@@ -95,6 +95,10 @@ const zImageFieldType = zFieldTypeBase.extend({
name: z.literal('ImageField'),
originalType: zStatelessFieldType.optional(),
});
const zImageBatchFieldType = zFieldTypeBase.extend({
name: z.literal('ImageBatchField'),
originalType: zStatelessFieldType.optional(),
});
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
originalType: zStatelessFieldType.optional(),
@@ -175,13 +179,14 @@ const zSchedulerFieldType = zFieldTypeBase.extend({
name: z.literal('SchedulerField'),
originalType: zStatelessFieldType.optional(),
});
const zStatefulFieldType = z.union([
export const zStatefulFieldType = z.union([
zIntegerFieldType,
zFloatFieldType,
zStringFieldType,
zBooleanFieldType,
zEnumFieldType,
zImageFieldType,
zImageBatchFieldType,
zBoardFieldType,
zModelIdentifierFieldType,
zMainModelFieldType,
@@ -367,6 +372,19 @@ export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputI
zImageFieldInputInstance.safeParse(val).success;
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
zImageFieldInputTemplate.safeParse(val).success;
const zImageBatchFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zImageBatchFieldType,
});
export type ImageBatchOutputFieldTemplate = z.infer<typeof zImageBatchFieldOutputTemplate>;
export const imageBatchOutputFieldTemplate: ImageBatchOutputFieldTemplate = {
fieldKind: 'output',
name: 'images',
ui_hidden: false,
type: { name: 'ImageBatchField', cardinality: 'SINGLE' },
title: 'Image',
};
// #endregion
// #region BoardField
@@ -991,6 +1009,7 @@ const zStatefulFieldOutputTemplate = z.union([
zBooleanFieldOutputTemplate,
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,
zImageBatchFieldOutputTemplate,
zBoardFieldOutputTemplate,
zModelIdentifierFieldOutputTemplate,
zMainModelFieldOutputTemplate,

View File

@@ -1,7 +1,7 @@
import type { Edge, Node } from 'reactflow';
import { z } from 'zod';
import { zClassification, zProgressImage } from './common';
import { zClassification, zImageField, zProgressImage } from './common';
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field';
import { zSemVer } from './semver';
@@ -49,23 +49,33 @@ const zCurrentImageNodeData = z.object({
label: z.string(),
isOpen: z.boolean(),
});
const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]);
const zImageBatchNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('image_batch'),
label: z.string(),
isOpen: z.boolean(),
images: z.array(zImageField),
});
const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData, zImageBatchNodeData]);
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
export type ImageBatchNodeData = z.infer<typeof zImageBatchNodeData>;
type AnyNodeData = z.infer<typeof zAnyNodeData>;
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
export type NotesNode = Node<NotesNodeData, 'notes'>;
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
export type BatchImageInputNode = Node<ImageBatchNodeData, 'image_batch'>;
export type AnyNode = Node<AnyNodeData>;
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
Boolean(node && node.type === 'invocation');
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData =>
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
export const isImageBatchNode = (node?: AnyNode | null): node is BatchImageInputNode =>
Boolean(node && node.type === 'image_batch');
// #endregion
// #region NodeExecutionState

View File

@@ -1,9 +1,12 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
const log = logger('workflows');
/**
* Builds a graph from the node editor state.
*/
@@ -47,22 +50,31 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
return nodesAccumulator;
}, {});
const filteredNodeIds = filteredNodes.map(({ id }) => id);
// skip out the "dummy" edges between collapsed nodes
const filteredEdges = edges.filter((n) => n.type !== 'collapsed');
const filteredEdges = edges
.filter((edge) => edge.type !== 'collapsed')
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
// Reduce the node editor edges into invocation graph edges
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
const { source, target, sourceHandle, targetHandle } = edge;
if (!sourceHandle || !targetHandle) {
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
return edgesAccumulator;
}
// Format the edges and add to the edges array
edgesAccumulator.push({
source: {
node_id: source,
field: sourceHandle as string,
field: sourceHandle,
},
destination: {
node_id: target,
field: targetHandle as string,
field: targetHandle,
},
});

View File

@@ -0,0 +1,22 @@
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type { BatchImageInputNode } from 'features/nodes/types/invocation';
import type { XYPosition } from 'reactflow';
import { v4 as uuidv4 } from 'uuid';
export const buildBatchInputImageNode = (position: XYPosition): BatchImageInputNode => {
const nodeId = uuidv4();
const node: BatchImageInputNode = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'image_batch',
position,
data: {
id: nodeId,
isOpen: true,
label: 'Image Batch',
type: 'image_batch',
images: [],
},
};
return node;
};

View File

@@ -1822,7 +1822,7 @@ export type components = {
* Items
* @description The list of items to substitute into the node/field.
*/
items?: (string | number)[];
items?: (string | number | components["schemas"]["ImageField"])[];
};
/**
* BatchEnqueuedEvent
@@ -6751,6 +6751,12 @@ export type components = {
* @default 1
*/
denoising_end?: number;
/**
* Add Noise
* @description Add noise based on denoising start.
* @default true
*/
add_noise?: boolean;
/**
* Transformer
* @description Flux model (Transformer) to load
@@ -13161,7 +13167,7 @@ export type components = {
* Value
* @description The value to substitute into the node/field.
*/
value: string | number;
value: string | number | components["schemas"]["ImageField"];
};
/**
* Noise