mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 06:18:03 -05:00
Compare commits
1 Commits
v5.9.1
...
psyche/fea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c3304dcbe3 |
@@ -16,6 +16,7 @@ from pydantic import (
|
||||
from pydantic_core import to_jsonable_python
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
||||
from invokeai.app.invocations.fields import ImageField
|
||||
from invokeai.app.services.shared.graph import Graph, GraphExecutionState, NodeNotFoundError
|
||||
from invokeai.app.services.workflow_records.workflow_records_common import (
|
||||
WorkflowWithoutID,
|
||||
@@ -51,11 +52,7 @@ class SessionQueueItemNotFoundError(ValueError):
|
||||
|
||||
# region Batch
|
||||
|
||||
BatchDataType = Union[
|
||||
StrictStr,
|
||||
float,
|
||||
int,
|
||||
]
|
||||
BatchDataType = Union[StrictStr, float, int, ImageField]
|
||||
|
||||
|
||||
class NodeFieldValue(BaseModel):
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isImageBatchNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { queueApi } from 'services/api/endpoints/queue';
|
||||
import type { BatchConfig } from 'services/api/types';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
@@ -26,6 +27,18 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
delete builtWorkflow.id;
|
||||
}
|
||||
|
||||
const data: Batch['data'] = [];
|
||||
|
||||
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
|
||||
const imageBatchNodes = nodes.nodes.filter(isImageBatchNode);
|
||||
for (const imageBatch of imageBatchNodes) {
|
||||
const edge = nodes.edges.find((e) => e.source === imageBatch.id);
|
||||
if (!edge || !edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
data.push([{ node_path: edge.target, field_name: edge.targetHandle, items: imageBatch.data.images }]);
|
||||
}
|
||||
|
||||
const batchConfig: BatchConfig = {
|
||||
batch: {
|
||||
graph,
|
||||
@@ -33,6 +46,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
runs: state.params.iterations,
|
||||
origin: 'workflows',
|
||||
destination: 'gallery',
|
||||
data,
|
||||
},
|
||||
prepend: action.payload.prepend,
|
||||
};
|
||||
|
||||
@@ -20,12 +20,15 @@ const sx = {
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
type Props = ImageProps & {
|
||||
imageDTO: ImageDTO;
|
||||
asThumbnail?: boolean;
|
||||
};
|
||||
/* eslint-disable-next-line @typescript-eslint/no-namespace */
|
||||
export namespace DndImage {
|
||||
export interface Props extends ImageProps {
|
||||
imageDTO: ImageDTO;
|
||||
asThumbnail?: boolean;
|
||||
}
|
||||
}
|
||||
|
||||
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: Props) => {
|
||||
export const DndImage = memo(({ imageDTO, asThumbnail, ...rest }: DndImage.Props) => {
|
||||
const store = useAppStore();
|
||||
const [isDragging, setIsDragging] = useState(false);
|
||||
const [element, ref] = useState<HTMLImageElement | null>(null);
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import { IAINoContentFallback, IAINoContentFallbackWithSpinner } from 'common/components/IAIImageFallback';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { memo } from 'react';
|
||||
import { PiExclamationMarkBold } from 'react-icons/pi';
|
||||
import { useGetImageDTOQuery } from 'services/api/endpoints/images';
|
||||
|
||||
/* eslint-disable-next-line @typescript-eslint/no-namespace */
|
||||
namespace DndImageFromImageName {
|
||||
export interface Props extends Omit<DndImage.Props, 'imageDTO'> {
|
||||
imageName: string;
|
||||
}
|
||||
}
|
||||
|
||||
export const DndImageFromImageName = memo(({ imageName, ...rest }: DndImageFromImageName.Props) => {
|
||||
const query = useGetImageDTOQuery(imageName);
|
||||
if (query.isLoading) {
|
||||
return <IAINoContentFallbackWithSpinner />;
|
||||
}
|
||||
if (!query.data) {
|
||||
return <IAINoContentFallback icon={<PiExclamationMarkBold />} />;
|
||||
}
|
||||
|
||||
return <DndImage imageDTO={query.data} {...rest} />;
|
||||
});
|
||||
|
||||
DndImageFromImageName.displayName = 'DndImageFromImageName';
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
setRegionalGuidanceReferenceImage,
|
||||
setUpscaleInitialImage,
|
||||
} from 'features/imageActions/actions';
|
||||
import { batchImageInputNodeImagesAdded } from 'features/nodes/store/nodesSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
@@ -241,6 +242,47 @@ export const setNodeImageFieldImageDndTarget: DndTarget<SetNodeImageFieldImageDn
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//#region Add Images to Batch Image Input Node
|
||||
const _addImagesToBatchImageInputNode = buildTypeAndKey('add-images-to-batch-image-input-node');
|
||||
export type AddImagesToBatchImageInputNodeDndTargetData = DndData<
|
||||
typeof _addImagesToBatchImageInputNode.type,
|
||||
typeof _addImagesToBatchImageInputNode.key,
|
||||
{ nodeId: string }
|
||||
>;
|
||||
export const addImagesToBatchImageInputNodeDndTarget: DndTarget<
|
||||
AddImagesToBatchImageInputNodeDndTargetData,
|
||||
SingleImageDndSourceData | MultipleImageDndSourceData
|
||||
> = {
|
||||
..._addImagesToBatchImageInputNode,
|
||||
typeGuard: buildTypeGuard(_addImagesToBatchImageInputNode.key),
|
||||
getData: buildGetData(_addImagesToBatchImageInputNode.key, _addImagesToBatchImageInputNode.type),
|
||||
isValid: ({ sourceData }) => {
|
||||
if (singleImageDndSource.typeGuard(sourceData) || multipleImageDndSource.typeGuard(sourceData)) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
},
|
||||
handler: ({ sourceData, targetData, dispatch }) => {
|
||||
if (!singleImageDndSource.typeGuard(sourceData) && !multipleImageDndSource.typeGuard(sourceData)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const { nodeId } = targetData.payload;
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
imageDTOs.push(sourceData.payload.imageDTO);
|
||||
} else {
|
||||
imageDTOs.push(...sourceData.payload.imageDTOs);
|
||||
}
|
||||
|
||||
const images = imageDTOs.map(({ image_name }) => ({ image_name }));
|
||||
|
||||
dispatch(batchImageInputNodeImagesAdded({ nodeId, images }));
|
||||
},
|
||||
};
|
||||
//#endregion
|
||||
|
||||
//# Set Comparison Image
|
||||
const _setComparisonImage = buildTypeAndKey('set-comparison-image');
|
||||
export type SetComparisonImageDndTargetData = DndData<
|
||||
@@ -430,6 +472,7 @@ export const dndTargets = [
|
||||
// Single or Multiple Image
|
||||
addImageToBoardDndTarget,
|
||||
removeImageFromBoardDndTarget,
|
||||
addImagesToBatchImageInputNodeDndTarget,
|
||||
] as const;
|
||||
|
||||
export type AnyDndTarget = (typeof dndTargets)[number];
|
||||
|
||||
@@ -347,6 +347,18 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
|
||||
}),
|
||||
[t]
|
||||
);
|
||||
const batchInputImageItem = useMemo<FilterableItem>(
|
||||
() =>
|
||||
({
|
||||
type: 'batch_input',
|
||||
title: 'batch_input',
|
||||
description: 'batch_input',
|
||||
tags: ['batch'],
|
||||
classification: 'stable',
|
||||
nodePack: 'invokeai',
|
||||
}) as const,
|
||||
[]
|
||||
);
|
||||
|
||||
const items = useMemo<NodeCommandItemData[]>(() => {
|
||||
// If we have a connection in progress, we need to filter the node choices
|
||||
@@ -365,7 +377,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
|
||||
}
|
||||
}
|
||||
|
||||
for (const item of [currentImageFilterItem, notesFilterItem]) {
|
||||
for (const item of [currentImageFilterItem, notesFilterItem, batchInputImageItem]) {
|
||||
if (filter(item, searchTerm)) {
|
||||
_items.push({
|
||||
label: item.title,
|
||||
@@ -403,7 +415,7 @@ const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; on
|
||||
}
|
||||
|
||||
return _items;
|
||||
}, [pendingConnection, currentImageFilterItem, searchTerm, notesFilterItem, templatesArray]);
|
||||
}, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem, batchInputImageItem]);
|
||||
|
||||
return (
|
||||
<>
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks';
|
||||
import { useFocusRegion, useIsRegionFocused } from 'common/hooks/focus';
|
||||
import { BatchImageInputNode } from 'features/nodes/components/flow/nodes/BatchInput/BatchImageInputNode';
|
||||
import { useConnection } from 'features/nodes/hooks/useConnection';
|
||||
import { useCopyPaste } from 'features/nodes/hooks/useCopyPaste';
|
||||
import { useSyncExecutionState } from 'features/nodes/hooks/useExecutionState';
|
||||
@@ -64,6 +65,7 @@ const nodeTypes = {
|
||||
invocation: InvocationNodeWrapper,
|
||||
current_image: CurrentImageNode,
|
||||
notes: NotesNode,
|
||||
image_batch: BatchImageInputNode,
|
||||
};
|
||||
|
||||
// TODO: can we support reactflow? if not, we could style the attribution so it matches the app
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
import { Box, Flex, Grid, GridItem, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import type { AddImagesToBatchImageInputNodeDndTargetData } from 'features/dnd/dnd';
|
||||
import { addImagesToBatchImageInputNodeDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImageFromImageName } from 'features/dnd/DndImageFromImageName';
|
||||
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
|
||||
import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
|
||||
import NodeWrapper from 'features/nodes/components/flow/nodes/common/NodeWrapper';
|
||||
import FieldHandle from 'features/nodes/components/flow/nodes/Invocation/fields/FieldHandle';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { batchImageInputNodeReset } from 'features/nodes/store/nodesSlice';
|
||||
import { imageBatchOutputFieldTemplate } from 'features/nodes/types/field';
|
||||
import type { ImageBatchNodeData } from 'features/nodes/types/invocation';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
import type { NodeProps } from 'reactflow';
|
||||
|
||||
export const BatchImageInputNode = memo((props: NodeProps<ImageBatchNodeData>) => {
|
||||
const { id: nodeId, data, selected } = props;
|
||||
const { images, isOpen } = data;
|
||||
const dispatch = useAppDispatch();
|
||||
const onReset = useCallback(() => {
|
||||
dispatch(batchImageInputNodeReset({ nodeId }));
|
||||
}, [dispatch, nodeId]);
|
||||
const targetData = useMemo<AddImagesToBatchImageInputNodeDndTargetData>(
|
||||
() => addImagesToBatchImageInputNodeDndTarget.getData({ nodeId }),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
return (
|
||||
<NodeWrapper nodeId={nodeId} selected={selected}>
|
||||
<Flex
|
||||
layerStyle="nodeHeader"
|
||||
borderTopRadius="base"
|
||||
borderBottomRadius={isOpen ? 0 : 'base'}
|
||||
alignItems="center"
|
||||
justifyContent="space-between"
|
||||
h={8}
|
||||
>
|
||||
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
|
||||
<NodeTitle nodeId={nodeId} title="Batch Image Input" />
|
||||
<Box minW={8} />
|
||||
</Flex>
|
||||
{isOpen && (
|
||||
<>
|
||||
<Flex
|
||||
position="relative"
|
||||
layerStyle="nodeBody"
|
||||
className="nopan"
|
||||
cursor="auto"
|
||||
flexDirection="column"
|
||||
borderBottomRadius="base"
|
||||
w="full"
|
||||
h="full"
|
||||
p={2}
|
||||
gap={1}
|
||||
minH={16}
|
||||
>
|
||||
<Grid className="nopan" w="full" h="full" templateColumns="repeat(3, 1fr)" gap={2}>
|
||||
{images.map(({ image_name }) => (
|
||||
<GridItem key={image_name}>
|
||||
<DndImageFromImageName imageName={image_name} asThumbnail />
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
<IconButton
|
||||
aria-label="reset"
|
||||
icon={<PiArrowCounterClockwiseBold />}
|
||||
position="absolute"
|
||||
top={0}
|
||||
insetInlineEnd={0}
|
||||
onClick={onReset}
|
||||
variant="ghost"
|
||||
/>
|
||||
</Flex>
|
||||
</>
|
||||
)}
|
||||
<ImageBatchOutputField nodeId={nodeId} />
|
||||
|
||||
<DndDropTarget
|
||||
dndTarget={addImagesToBatchImageInputNodeDndTarget}
|
||||
dndTargetData={targetData}
|
||||
label="Add to Batch"
|
||||
/>
|
||||
</NodeWrapper>
|
||||
);
|
||||
});
|
||||
|
||||
BatchImageInputNode.displayName = 'BatchImageInputNode';
|
||||
|
||||
const ImageBatchOutputField = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName: 'images', kind: 'outputs' });
|
||||
|
||||
return (
|
||||
<Flex
|
||||
position="absolute"
|
||||
minH={8}
|
||||
top="50%"
|
||||
translateY="-50%"
|
||||
insetInlineEnd={2}
|
||||
alignItems="center"
|
||||
opacity={shouldDim ? 0.5 : 1}
|
||||
transitionProperty="opacity"
|
||||
transitionDuration="0.1s"
|
||||
justifyContent="flex-end"
|
||||
>
|
||||
<FieldHandle
|
||||
fieldTemplate={imageBatchOutputFieldTemplate}
|
||||
handleType="source"
|
||||
isConnectionInProgress={isConnectionInProgress}
|
||||
isConnectionStartField={isConnectionStartField}
|
||||
validationResult={validationResult}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
ImageBatchOutputField.displayName = 'ImageBatchOutputField';
|
||||
@@ -1,9 +1,9 @@
|
||||
import { Flex, Icon, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { compare } from 'compare-versions';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNode } from 'features/nodes/hooks/useNode';
|
||||
import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate';
|
||||
import { useNodeTemplate } from 'features/nodes/hooks/useNodeTemplate';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiInfoBold } from 'react-icons/pi';
|
||||
@@ -25,32 +25,32 @@ const InvocationNodeInfoIcon = ({ nodeId }: Props) => {
|
||||
export default memo(InvocationNodeInfoIcon);
|
||||
|
||||
const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
const data = useNodeData(nodeId);
|
||||
const node = useNode(nodeId);
|
||||
const nodeTemplate = useNodeTemplate(nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const title = useMemo(() => {
|
||||
if (data?.label && nodeTemplate?.title) {
|
||||
return `${data.label} (${nodeTemplate.title})`;
|
||||
if (node.data?.label && nodeTemplate?.title) {
|
||||
return `${node.data.label} (${nodeTemplate.title})`;
|
||||
}
|
||||
|
||||
if (data?.label && !nodeTemplate) {
|
||||
return data.label;
|
||||
if (node.data?.label && !nodeTemplate) {
|
||||
return node.data.label;
|
||||
}
|
||||
|
||||
if (!data?.label && nodeTemplate) {
|
||||
if (!node.data?.label && nodeTemplate) {
|
||||
return nodeTemplate.title;
|
||||
}
|
||||
|
||||
return t('nodes.unknownNode');
|
||||
}, [data, nodeTemplate, t]);
|
||||
}, [node.data.label, nodeTemplate, t]);
|
||||
|
||||
const versionComponent = useMemo(() => {
|
||||
if (!isInvocationNodeData(data) || !nodeTemplate) {
|
||||
if (!isInvocationNode(node) || !nodeTemplate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!data.version) {
|
||||
if (!node.data.version) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.versionUnknown')}
|
||||
@@ -61,35 +61,35 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
if (!nodeTemplate.version) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.unknownTemplate')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.unknownTemplate')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '<')) {
|
||||
if (compare(node.data.version, nodeTemplate.version, '<')) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.updateNode')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.updateNode')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
if (compare(data.version, nodeTemplate.version, '>')) {
|
||||
if (compare(node.data.version, nodeTemplate.version, '>')) {
|
||||
return (
|
||||
<Text as="span" color="error.500">
|
||||
{t('nodes.version')} {data.version} ({t('nodes.updateApp')})
|
||||
{t('nodes.version')} {node.data.version} ({t('nodes.updateApp')})
|
||||
</Text>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Text as="span">
|
||||
{t('nodes.version')} {data.version}
|
||||
{t('nodes.version')} {node.data.version}
|
||||
</Text>
|
||||
);
|
||||
}, [data, nodeTemplate, t]);
|
||||
}, [node, nodeTemplate, t]);
|
||||
|
||||
if (!isInvocationNodeData(data)) {
|
||||
if (!isInvocationNode(node)) {
|
||||
return <Text fontWeight="semibold">{t('nodes.unknownNode')}</Text>;
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ const TooltipContent = memo(({ nodeId }: { nodeId: string }) => {
|
||||
{nodeTemplate?.description}
|
||||
</Text>
|
||||
{versionComponent}
|
||||
{data?.notes && <Text>{data.notes}</Text>}
|
||||
{node.data?.notes && <Text>{node.data.notes}</Text>}
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import { FormControl, FormLabel, Textarea } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useNodeData } from 'features/nodes/hooks/useNodeData';
|
||||
import { useNode } from 'features/nodes/hooks/useNode';
|
||||
import { nodeNotesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { isInvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const data = useNodeData(nodeId);
|
||||
const node = useNode(nodeId);
|
||||
const { t } = useTranslation();
|
||||
const handleNotesChanged = useCallback(
|
||||
(e: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
@@ -17,13 +17,13 @@ const NotesTextarea = ({ nodeId }: { nodeId: string }) => {
|
||||
},
|
||||
[dispatch, nodeId]
|
||||
);
|
||||
if (!isInvocationNodeData(data)) {
|
||||
if (!isInvocationNode(node)) {
|
||||
return null;
|
||||
}
|
||||
return (
|
||||
<FormControl orientation="vertical" h="full">
|
||||
<FormLabel>{t('nodes.notes')}</FormLabel>
|
||||
<Textarea value={data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
|
||||
<Textarea value={node.data?.notes} onChange={handleNotesChanged} rows={10} resize="none" />
|
||||
</FormControl>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -23,6 +23,8 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const { isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim } =
|
||||
useConnectionState({ nodeId, fieldName, kind: 'inputs' });
|
||||
|
||||
console.log({ isConnected, isConnectionInProgress, isConnectionStartField, validationResult, shouldDim });
|
||||
|
||||
const isMissingInput = useMemo(() => {
|
||||
if (!fieldTemplate) {
|
||||
return false;
|
||||
|
||||
@@ -64,7 +64,7 @@ type OutputFieldWrapperProps = PropsWithChildren<{
|
||||
shouldDim: boolean;
|
||||
}>;
|
||||
|
||||
const OutputFieldWrapper = memo(({ shouldDim, children }: OutputFieldWrapperProps) => (
|
||||
export const OutputFieldWrapper = memo(({ shouldDim, children }: OutputFieldWrapperProps) => (
|
||||
<Flex
|
||||
position="relative"
|
||||
minH={8}
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useStore } from '@nanostores/react';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { NODE_WIDTH } from 'features/nodes/types/constants';
|
||||
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { buildBatchInputImageNode } from 'features/nodes/util/node/buildBatchInputImageNode';
|
||||
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
|
||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||
import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode';
|
||||
@@ -14,7 +15,7 @@ export const useBuildNode = () => {
|
||||
|
||||
return useCallback(
|
||||
// string here is "any invocation type"
|
||||
(type: string | 'current_image' | 'notes'): AnyNode => {
|
||||
(type: string | 'current_image' | 'notes' | 'batch_input'): AnyNode => {
|
||||
let _x = window.innerWidth / 2;
|
||||
let _y = window.innerHeight / 2;
|
||||
|
||||
@@ -39,6 +40,10 @@ export const useBuildNode = () => {
|
||||
return buildNotesNode(position);
|
||||
}
|
||||
|
||||
if (type === 'batch_input') {
|
||||
return buildBatchInputImageNode(position);
|
||||
}
|
||||
|
||||
// TODO: Keep track of invocation types so we do not need to cast this
|
||||
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
|
||||
const template = templates[type] as InvocationTemplate;
|
||||
|
||||
@@ -12,6 +12,11 @@ import {
|
||||
import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import {
|
||||
type FieldInputTemplate,
|
||||
type FieldOutputTemplate,
|
||||
imageBatchOutputFieldTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import type { EdgeChange, OnConnect, OnConnectEnd, OnConnectStart } from 'reactflow';
|
||||
import { useUpdateNodeInternals } from 'reactflow';
|
||||
@@ -33,13 +38,20 @@ export const useConnection = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return;
|
||||
let fieldTemplate: FieldInputTemplate | FieldOutputTemplate | undefined = undefined;
|
||||
|
||||
if (node.type === 'image_batch' && handleId === 'images' && handleType === 'source') {
|
||||
fieldTemplate = imageBatchOutputFieldTemplate;
|
||||
} else {
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return;
|
||||
}
|
||||
|
||||
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
|
||||
fieldTemplate = fieldTemplates[handleId];
|
||||
}
|
||||
|
||||
const fieldTemplates = template[handleType === 'source' ? 'outputs' : 'inputs'];
|
||||
const fieldTemplate = fieldTemplates[handleId];
|
||||
if (!fieldTemplate) {
|
||||
return;
|
||||
}
|
||||
|
||||
19
invokeai/frontend/web/src/features/nodes/hooks/useNode.ts
Normal file
19
invokeai/frontend/web/src/features/nodes/hooks/useNode.ts
Normal file
@@ -0,0 +1,19 @@
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { selectNode, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { useMemo } from 'react';
|
||||
import type { Node } from 'reactflow';
|
||||
|
||||
export const useNode = (nodeId: string): Node => {
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createMemoizedSelector(selectNodesSlice, (nodes) => {
|
||||
return selectNode(nodes, nodeId);
|
||||
}),
|
||||
[nodeId]
|
||||
);
|
||||
|
||||
const node = useAppSelector(selector);
|
||||
|
||||
return node;
|
||||
};
|
||||
@@ -3,6 +3,7 @@ import { createSlice, isAnyOf } from '@reduxjs/toolkit';
|
||||
import type { PersistConfig } from 'app/store/store';
|
||||
import { buildUseBoolean } from 'common/hooks/useBoolean';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
BoardFieldValue,
|
||||
@@ -58,7 +59,8 @@ import {
|
||||
zVAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { isImageBatchNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { uniqBy } from 'lodash-es';
|
||||
import { atom, computed } from 'nanostores';
|
||||
import type { MouseEvent } from 'react';
|
||||
import type { Edge, EdgeChange, NodeChange, Viewport, XYPosition } from 'reactflow';
|
||||
@@ -193,7 +195,7 @@ export const nodesSlice = createSlice({
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
if (!isInvocationNode(node) && !isNotesNode(node)) {
|
||||
if (!isInvocationNode(node) && !isNotesNode(node) && !isImageBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -382,6 +384,34 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = value;
|
||||
},
|
||||
batchImageInputNodeImagesAdded: (state, action: PayloadAction<{ nodeId: string; images: ImageField[] }>) => {
|
||||
const { nodeId, images } = action.payload;
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isImageBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
node.data.images = uniqBy([...node.data.images, ...images], 'image_name');
|
||||
},
|
||||
batchImageInputNodeImagesRemoved: (state, action: PayloadAction<{ nodeId: string; images: ImageField[] }>) => {
|
||||
const { nodeId, images } = action.payload;
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isImageBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
const imageNamesToRemove = images.map(({ image_name }) => image_name);
|
||||
node.data.images = uniqBy(
|
||||
node.data.images.filter(({ image_name }) => imageNamesToRemove.includes(image_name)),
|
||||
'image_name'
|
||||
);
|
||||
},
|
||||
batchImageInputNodeReset: (state, action: PayloadAction<{ nodeId: string }>) => {
|
||||
const { nodeId } = action.payload;
|
||||
const node = state.nodes.find((n) => n.id === nodeId);
|
||||
if (!isImageBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
node.data.images = [];
|
||||
},
|
||||
nodeEditorReset: (state) => {
|
||||
state.nodes = [];
|
||||
state.edges = [];
|
||||
@@ -443,6 +473,9 @@ export const {
|
||||
notesNodeValueChanged,
|
||||
undo,
|
||||
redo,
|
||||
batchImageInputNodeImagesAdded,
|
||||
batchImageInputNodeImagesRemoved,
|
||||
batchImageInputNodeReset,
|
||||
} = nodesSlice.actions;
|
||||
|
||||
export const $cursorPos = atom<XYPosition | null>(null);
|
||||
|
||||
@@ -5,8 +5,15 @@ import type { NodesState } from 'features/nodes/store/types';
|
||||
import type { FieldInputInstance } from 'features/nodes/types/field';
|
||||
import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { Node } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
export const selectNode = (nodesSlice: NodesState, nodeId: string): Node => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(node !== undefined, `Node ${nodeId} not found`);
|
||||
return node;
|
||||
};
|
||||
|
||||
export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => {
|
||||
const node = nodesSlice.nodes.find((node) => node.id === nodeId);
|
||||
assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`);
|
||||
|
||||
@@ -3,6 +3,7 @@ import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import { imageBatchOutputFieldTemplate } from 'features/nodes/types/field';
|
||||
import type { AnyNode } from 'features/nodes/types/invocation';
|
||||
import type { Connection as NullableConnection, Edge } from 'reactflow';
|
||||
import type { SetNonNullable } from 'type-fest';
|
||||
@@ -78,23 +79,32 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!targetTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
if (sourceNode.type === 'image_batch') {
|
||||
const isValid = validateConnectionTypes(imageBatchOutputFieldTemplate.type, targetFieldTemplate.type);
|
||||
if (!isValid) {
|
||||
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||
} else {
|
||||
return buildAcceptResult();
|
||||
}
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,10 @@ export const validateConnectionTypes = (sourceType: FieldType, targetType: Field
|
||||
return false;
|
||||
}
|
||||
|
||||
if (sourceType.name === 'ImageBatchField') {
|
||||
return isSingle(sourceType) && targetType.name === 'ImageField';
|
||||
}
|
||||
|
||||
if (areTypesEqual(sourceType, targetType)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -59,6 +59,7 @@ export const FIELD_COLORS: { [key: string]: string } = {
|
||||
EnumField: 'blue.500',
|
||||
FloatField: 'orange.500',
|
||||
ImageField: 'purple.500',
|
||||
ImageBatchField: 'purple.500',
|
||||
IntegerField: 'red.500',
|
||||
IPAdapterField: 'teal.500',
|
||||
IPAdapterModelField: 'teal.500',
|
||||
|
||||
@@ -95,6 +95,10 @@ const zImageFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zImageBatchFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('ImageBatchField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -175,13 +179,14 @@ const zSchedulerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('SchedulerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStatefulFieldType = z.union([
|
||||
export const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
zStringFieldType,
|
||||
zBooleanFieldType,
|
||||
zEnumFieldType,
|
||||
zImageFieldType,
|
||||
zImageBatchFieldType,
|
||||
zBoardFieldType,
|
||||
zModelIdentifierFieldType,
|
||||
zMainModelFieldType,
|
||||
@@ -367,6 +372,19 @@ export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputI
|
||||
zImageFieldInputInstance.safeParse(val).success;
|
||||
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
|
||||
zImageFieldInputTemplate.safeParse(val).success;
|
||||
|
||||
const zImageBatchFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zImageBatchFieldType,
|
||||
});
|
||||
export type ImageBatchOutputFieldTemplate = z.infer<typeof zImageBatchFieldOutputTemplate>;
|
||||
|
||||
export const imageBatchOutputFieldTemplate: ImageBatchOutputFieldTemplate = {
|
||||
fieldKind: 'output',
|
||||
name: 'images',
|
||||
ui_hidden: false,
|
||||
type: { name: 'ImageBatchField', cardinality: 'SINGLE' },
|
||||
title: 'Image',
|
||||
};
|
||||
// #endregion
|
||||
|
||||
// #region BoardField
|
||||
@@ -991,6 +1009,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zBooleanFieldOutputTemplate,
|
||||
zEnumFieldOutputTemplate,
|
||||
zImageFieldOutputTemplate,
|
||||
zImageBatchFieldOutputTemplate,
|
||||
zBoardFieldOutputTemplate,
|
||||
zModelIdentifierFieldOutputTemplate,
|
||||
zMainModelFieldOutputTemplate,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import type { Edge, Node } from 'reactflow';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zClassification, zProgressImage } from './common';
|
||||
import { zClassification, zImageField, zProgressImage } from './common';
|
||||
import { zFieldInputInstance, zFieldInputTemplate, zFieldOutputTemplate } from './field';
|
||||
import { zSemVer } from './semver';
|
||||
|
||||
@@ -49,23 +49,33 @@ const zCurrentImageNodeData = z.object({
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
});
|
||||
const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData]);
|
||||
const zImageBatchNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('image_batch'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
images: z.array(zImageField),
|
||||
});
|
||||
const zAnyNodeData = z.union([zInvocationNodeData, zNotesNodeData, zCurrentImageNodeData, zImageBatchNodeData]);
|
||||
|
||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
|
||||
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
|
||||
export type ImageBatchNodeData = z.infer<typeof zImageBatchNodeData>;
|
||||
|
||||
type AnyNodeData = z.infer<typeof zAnyNodeData>;
|
||||
|
||||
export type InvocationNode = Node<InvocationNodeData, 'invocation'>;
|
||||
export type NotesNode = Node<NotesNodeData, 'notes'>;
|
||||
export type CurrentImageNode = Node<CurrentImageNodeData, 'current_image'>;
|
||||
export type BatchImageInputNode = Node<ImageBatchNodeData, 'image_batch'>;
|
||||
export type AnyNode = Node<AnyNodeData>;
|
||||
|
||||
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isInvocationNodeData = (node?: AnyNodeData | null): node is InvocationNodeData =>
|
||||
Boolean(node && !['notes', 'current_image'].includes(node.type)); // node.type may be 'notes', 'current_image', or any invocation type
|
||||
export const isImageBatchNode = (node?: AnyNode | null): node is BatchImageInputNode =>
|
||||
Boolean(node && node.type === 'image_batch');
|
||||
// #endregion
|
||||
|
||||
// #region NodeExecutionState
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { omit, reduce } from 'lodash-es';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
const log = logger('workflows');
|
||||
|
||||
/**
|
||||
* Builds a graph from the node editor state.
|
||||
*/
|
||||
@@ -47,22 +50,31 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
return nodesAccumulator;
|
||||
}, {});
|
||||
|
||||
const filteredNodeIds = filteredNodes.map(({ id }) => id);
|
||||
|
||||
// skip out the "dummy" edges between collapsed nodes
|
||||
const filteredEdges = edges.filter((n) => n.type !== 'collapsed');
|
||||
const filteredEdges = edges
|
||||
.filter((edge) => edge.type !== 'collapsed')
|
||||
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
|
||||
|
||||
// Reduce the node editor edges into invocation graph edges
|
||||
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
|
||||
const { source, target, sourceHandle, targetHandle } = edge;
|
||||
|
||||
if (!sourceHandle || !targetHandle) {
|
||||
log.warn({ source, target, sourceHandle, targetHandle }, 'Missing source or taget handle for edge');
|
||||
return edgesAccumulator;
|
||||
}
|
||||
|
||||
// Format the edges and add to the edges array
|
||||
edgesAccumulator.push({
|
||||
source: {
|
||||
node_id: source,
|
||||
field: sourceHandle as string,
|
||||
field: sourceHandle,
|
||||
},
|
||||
destination: {
|
||||
node_id: target,
|
||||
field: targetHandle as string,
|
||||
field: targetHandle,
|
||||
},
|
||||
});
|
||||
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type { BatchImageInputNode } from 'features/nodes/types/invocation';
|
||||
import type { XYPosition } from 'reactflow';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const buildBatchInputImageNode = (position: XYPosition): BatchImageInputNode => {
|
||||
const nodeId = uuidv4();
|
||||
const node: BatchImageInputNode = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'image_batch',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
isOpen: true,
|
||||
label: 'Image Batch',
|
||||
type: 'image_batch',
|
||||
images: [],
|
||||
},
|
||||
};
|
||||
return node;
|
||||
};
|
||||
@@ -1822,7 +1822,7 @@ export type components = {
|
||||
* Items
|
||||
* @description The list of items to substitute into the node/field.
|
||||
*/
|
||||
items?: (string | number)[];
|
||||
items?: (string | number | components["schemas"]["ImageField"])[];
|
||||
};
|
||||
/**
|
||||
* BatchEnqueuedEvent
|
||||
@@ -6751,6 +6751,12 @@ export type components = {
|
||||
* @default 1
|
||||
*/
|
||||
denoising_end?: number;
|
||||
/**
|
||||
* Add Noise
|
||||
* @description Add noise based on denoising start.
|
||||
* @default true
|
||||
*/
|
||||
add_noise?: boolean;
|
||||
/**
|
||||
* Transformer
|
||||
* @description Flux model (Transformer) to load
|
||||
@@ -13161,7 +13167,7 @@ export type components = {
|
||||
* Value
|
||||
* @description The value to substitute into the node/field.
|
||||
*/
|
||||
value: string | number;
|
||||
value: string | number | components["schemas"]["ImageField"];
|
||||
};
|
||||
/**
|
||||
* Noise
|
||||
|
||||
Reference in New Issue
Block a user