From dd056067a9ac23b530ab9f6ba8c410cf252b092d Mon Sep 17 00:00:00 2001 From: Jonathan <34005131+JPPhoto@users.noreply.github.com> Date: Mon, 13 Apr 2026 19:03:59 -0500 Subject: [PATCH] Added workflow connectors (#9027) * Add persisted workflow connectors * Polish workflow connector menu and visuals * Refine connector sizing and alignment * Fix connector deletion and unresolved constraints * Format connector deletion tests * Revert frontend test config tweaks * Address workflow connector review feedback * Format workflow flow menu memo --- invokeai/frontend/web/public/locales/en.json | 2 + .../features/nodes/components/flow/Flow.tsx | 347 +++++++++++++++++- .../flow/edges/util/buildEdgeSelectors.ts | 28 +- .../flow/nodes/Connector/ConnectorNode.tsx | 97 +++++ .../nodes/common/NonInvocationNodeWrapper.tsx | 11 +- .../components/flow/nodes/common/shared.ts | 8 +- .../src/features/nodes/hooks/useBuildNode.ts | 7 +- .../src/features/nodes/hooks/useConnection.ts | 47 +++ .../features/nodes/store/nodesSlice.test.ts | 160 ++++++++ .../src/features/nodes/store/nodesSlice.ts | 78 +++- .../store/util/connectorTopology.test.ts | 126 +++++++ .../nodes/store/util/connectorTopology.ts | 228 ++++++++++++ .../util/getFirstValidConnection.test.ts | 112 +++++- .../store/util/getFirstValidConnection.ts | 81 +++- .../nodes/store/util/reactFlowUtil.test.ts | 23 ++ .../nodes/store/util/reactFlowUtil.ts | 1 + .../store/util/validateConnection.test.ts | 225 ++++++++++++ .../nodes/store/util/validateConnection.ts | 316 +++++++++++++--- .../src/features/nodes/types/invocation.ts | 20 +- .../web/src/features/nodes/types/workflow.ts | 12 +- .../nodes/util/graph/buildNodesGraph.test.ts | 194 ++++++++++ .../nodes/util/graph/buildNodesGraph.ts | 56 ++- .../nodes/util/node/buildConnectorNode.ts | 21 ++ .../nodes/util/workflow/buildWorkflow.ts | 5 +- .../nodes/util/workflow/graphToWorkflow.ts | 2 +- .../nodes/util/workflow/migrations.ts | 12 +- .../util/workflow/validateWorkflow.test.ts | 149 +++++++- .../nodes/util/workflow/validateWorkflow.ts | 20 +- 28 files changed, 2260 insertions(+), 128 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 201ea8badb..7fad98531e 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1417,6 +1417,8 @@ "notes": "Notes", "description": "Description", "notesDescription": "Add notes about your workflow", + "addConnector": "Add Connector", + "deleteConnector": "Delete Connector", "problemSettingTitle": "Problem Setting Title", "resetToDefaultValue": "Reset to default value", "reloadNodeTemplates": "Reload Node Templates", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 0c48eddfc6..2fc5e12384 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -1,4 +1,4 @@ -import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; +import { Menu, MenuButton, MenuItem, MenuList, Portal, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import type { EdgeChange, @@ -32,7 +32,9 @@ import { $edgePendingUpdate, $lastEdgeUpdateMouseEvent, $pendingConnection, + $templates, $viewport, + connectorInserted, edgesChanged, nodesChanged, redo, @@ -46,18 +48,24 @@ import { selectNodes, selectNodesSlice, } from 'features/nodes/store/selectors'; +import { getConnectorDeletionSpliceConnections } from 'features/nodes/store/util/connectorTopology'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice'; import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; import type { CSSProperties, MouseEvent, RefObject } from 'react'; -import { memo, useCallback, useMemo, useRef } from 'react'; +import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; +import { PiPlugsConnectedBold, PiTrashBold } from 'react-icons/pi'; import CustomConnectionLine from './connectionLines/CustomConnectionLine'; import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge'; import InvocationDefaultEdge from './edges/InvocationDefaultEdge'; +import ConnectorNode from './nodes/Connector/ConnectorNode'; import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode'; import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; import NotesNode from './nodes/Notes/NotesNode'; @@ -70,6 +78,7 @@ const edgeTypes = { const nodeTypes = { invocation: InvocationNodeWrapper, + connector: ConnectorNode, current_image: CurrentImageNode, notes: NotesNode, } as const; @@ -81,17 +90,69 @@ const snapGrid: [number, number] = [25, 25]; const selectCancelConnection = (state: ReactFlowState) => state.cancelConnection; +type WorkflowContextMenuState = + | { + kind: 'pane'; + clientX: number; + clientY: number; + pageX: number; + pageY: number; + } + | { + kind: 'connector'; + connectorId: string; + pageX: number; + pageY: number; + } + | null; + +const getWorkflowContextMenuState = ( + event: globalThis.MouseEvent, + flowWrapper: HTMLDivElement | null +): WorkflowContextMenuState => { + if (event.shiftKey || !(event.target instanceof Element) || !flowWrapper?.contains(event.target)) { + return null; + } + + const connectorId = event.target.closest('[data-connector-node-id]')?.dataset.connectorNodeId; + if (connectorId) { + return { + kind: 'connector', + connectorId, + pageX: event.pageX, + pageY: event.pageY, + }; + } + + const paneTarget = event.target.closest('.react-flow__pane'); + if (paneTarget && flowWrapper.contains(paneTarget)) { + return { + kind: 'pane', + clientX: event.clientX, + clientY: event.clientY, + pageX: event.pageX, + pageY: event.pageY, + }; + } + + return null; +}; + export const Flow = memo(() => { + const { t } = useTranslation(); const dispatch = useAppDispatch(); const nodes = useAppSelector(selectNodes); const edges = useAppSelector(selectEdges); + const templates = useStore($templates); const viewport = useStore($viewport); const shouldSnapToGrid = useAppSelector(selectShouldSnapToGrid); const selectionMode = useAppSelector(selectSelectionMode); const { onConnectStart, onConnect, onConnectEnd } = useConnection(); const flowWrapper = useRef(null); + const pendingNodeInternalsUpdateRef = useRef(null); const isValidConnection = useIsValidConnection(); const updateNodeInternals = useUpdateNodeInternals(); + const [contextMenuState, setContextMenuState] = useState(null); useFocusRegion('workflows', flowWrapper); @@ -127,9 +188,16 @@ export const Flow = memo(() => { }, []); const { onCloseGlobal } = useGlobalMenuClose(); - const handlePaneClick = useCallback(() => { - onCloseGlobal(); - }, [onCloseGlobal]); + const handlePaneClick: NonNullable['onPaneClick']> = useCallback( + (event) => { + if ('button' in event && event.button !== 0) { + return; + } + onCloseGlobal(); + setContextMenuState(null); + }, + [onCloseGlobal] + ); const onInit: OnInit = useCallback((flow) => { $flow.set(flow); @@ -147,6 +215,149 @@ export const Flow = memo(() => { } }, []); + useEffect(() => { + const pendingNodeIds = pendingNodeInternalsUpdateRef.current; + if (!pendingNodeIds) { + return; + } + + pendingNodeInternalsUpdateRef.current = null; + + const frameId = requestAnimationFrame(() => { + updateNodeInternals([...new Set(pendingNodeIds)]); + }); + + return () => { + cancelAnimationFrame(frameId); + }; + }, [edges, nodes, updateNodeInternals]); + + const addConnectorAtPaneMenuPosition = useCallback(() => { + if (contextMenuState?.kind !== 'pane') { + return; + } + const flow = $flow.get(); + if (!flow) { + return; + } + const connector = buildConnectorNode( + flow.screenToFlowPosition({ + x: contextMenuState.clientX, + y: contextMenuState.clientY, + }) + ); + dispatch(nodesChanged([{ type: 'add', item: connector }])); + setContextMenuState(null); + }, [contextMenuState, dispatch]); + + const connectorSpliceConnections = useMemo( + () => + contextMenuState?.kind === 'connector' + ? getConnectorDeletionSpliceConnections( + contextMenuState.connectorId, + nodes, + edges, + templates, + validateConnection + ) + : null, + [contextMenuState, edges, nodes, templates] + ); + + const deleteConnectorFromContextMenu = useCallback(() => { + if (contextMenuState?.kind !== 'connector' || !connectorSpliceConnections) { + return; + } + const connectorEdgeRemovals: EdgeChange[] = edges + .filter((edge) => edge.source === contextMenuState.connectorId || edge.target === contextMenuState.connectorId) + .map((edge) => ({ type: 'remove', id: edge.id })); + const spliceEdgeAdditions: EdgeChange[] = connectorSpliceConnections.map((connection) => ({ + type: 'add', + item: connectionToEdge(connection), + })); + + pendingNodeInternalsUpdateRef.current = [ + contextMenuState.connectorId, + ...connectorSpliceConnections.flatMap((connection) => [connection.source, connection.target]), + ]; + dispatch(edgesChanged([...connectorEdgeRemovals, ...spliceEdgeAdditions])); + dispatch(nodesChanged([{ type: 'remove', id: contextMenuState.connectorId }])); + setContextMenuState(null); + }, [connectorSpliceConnections, contextMenuState, dispatch, edges]); + + useEffect(() => { + const onWindowContextMenu = (event: globalThis.MouseEvent) => { + const nextContextMenuState = getWorkflowContextMenuState(event, flowWrapper.current); + if (!nextContextMenuState) { + return; + } + + event.preventDefault(); + event.stopPropagation(); + setContextMenuState(nextContextMenuState); + }; + + window.addEventListener('contextmenu', onWindowContextMenu, { capture: true }); + + return () => { + window.removeEventListener('contextmenu', onWindowContextMenu, { capture: true }); + }; + }, []); + + const renderContextMenu = useCallback(() => { + if (contextMenuState?.kind === 'pane') { + return ( + + } onClick={addConnectorAtPaneMenuPosition}> + {t('nodes.addConnector')} + + + ); + } + + if (contextMenuState?.kind === 'connector') { + return ( + + } + onClick={deleteConnectorFromContextMenu} + isDisabled={!connectorSpliceConnections} + isDestructive + > + {t('nodes.deleteConnector')} + + + ); + } + + return ; + }, [addConnectorAtPaneMenuPosition, connectorSpliceConnections, contextMenuState, deleteConnectorFromContextMenu, t]); + + const closeContextMenu = useCallback(() => { + setContextMenuState(null); + }, []); + + const onEdgeDoubleClick = useCallback>( + (event, edge) => { + if (edge.type !== 'default' || edge.hidden) { + return; + } + const flow = $flow.get(); + if (!flow) { + return; + } + const connector = buildConnectorNode( + flow.screenToFlowPosition({ + x: event.clientX, + y: event.clientY, + }) + ); + dispatch(connectorInserted({ edgeId: edge.id, connector })); + updateNodeInternals([edge.source, edge.target, connector.id]); + }, + [dispatch, updateNodeInternals] + ); + // #region Updatable Edges /** @@ -209,16 +420,130 @@ export const Flow = memo(() => { // #endregion + const renderedNodes = useMemo(() => nodes, [nodes]); + + const renderedEdges = useMemo(() => edges, [edges]); + const contextMenuPosition = contextMenuState ? { x: contextMenuState.pageX, y: contextMenuState.pageY } : null; + const contextMenuKey = contextMenuPosition ? `${contextMenuPosition.x}-${contextMenuPosition.y}` : 'closed'; + return ( <> + + + + + {renderContextMenu()} + + + + + ); +}); + +Flow.displayName = 'Flow'; + +type FlowSurfaceProps = { + flowWrapper: { current: HTMLDivElement | null }; + viewport: ReactFlowProps['defaultViewport']; + renderedNodes: AnyNode[]; + renderedEdges: AnyEdge[]; + onInit: OnInit; + onMouseMove: (event: MouseEvent) => void; + onNodesChange: OnNodesChange; + onEdgesChange: OnEdgesChange; + onReconnect: OnReconnect; + onReconnectStart: NonNullable['onReconnectStart']>; + onReconnectEnd: NonNullable['onReconnectEnd']>; + onConnectStart: NonNullable['onConnectStart']>; + onConnect: NonNullable['onConnect']>; + onConnectEnd: NonNullable['onConnectEnd']>; + handleMoveEnd: OnMoveEnd; + onEdgeDoubleClick: NonNullable['onEdgeDoubleClick']>; + isValidConnection: NonNullable['isValidConnection']>; + shouldSnapToGrid: boolean; + flowStyles: CSSProperties; + handlePaneClick: NonNullable['onPaneClick']>; + selectionMode: ReturnType; +}; + +const FlowSurface = memo((props: FlowSurfaceProps) => { + const { + flowWrapper, + viewport, + renderedNodes, + renderedEdges, + onInit, + onMouseMove, + onNodesChange, + onEdgesChange, + onReconnect, + onReconnectStart, + onReconnectEnd, + onConnectStart, + onConnect, + onConnectEnd, + handleMoveEnd, + onEdgeDoubleClick, + isValidConnection, + shouldSnapToGrid, + flowStyles, + handlePaneClick, + selectionMode, + } = props; + + const setFlowWrapperElement = useCallback( + (el: HTMLDivElement | null) => { + flowWrapper.current = el; + }, + [flowWrapper] + ); + + return ( +
id="workflow-editor" - ref={flowWrapper} defaultViewport={viewport} nodeTypes={nodeTypes} edgeTypes={edgeTypes} - nodes={nodes} - edges={edges} + nodes={renderedNodes} + edges={renderedEdges} onInit={onInit} onMouseMove={onMouseMove} onNodesChange={onNodesChange} @@ -230,6 +555,7 @@ export const Flow = memo(() => { onConnect={onConnect} onConnectEnd={onConnectEnd} onMoveEnd={handleMoveEnd} + onEdgeDoubleClick={onEdgeDoubleClick} connectionLineComponent={CustomConnectionLine} isValidConnection={isValidConnection} minZoom={0.1} @@ -249,12 +575,11 @@ export const Flow = memo(() => { > - - +
); }); -Flow.displayName = 'Flow'; +FlowSurface.displayName = 'FlowSurface'; const HotkeyIsolator = memo(({ flowWrapper }: { flowWrapper: RefObject }) => { const mayUndo = useAppSelector(selectMayUndo); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts index b64e5a6e6a..40a44a16c2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts @@ -1,9 +1,10 @@ import { createSelector } from '@reduxjs/toolkit'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { selectNodes } from 'features/nodes/store/selectors'; +import { selectEdges, selectNodes } from 'features/nodes/store/selectors'; import type { Templates } from 'features/nodes/store/types'; +import { resolveConnectorSource, resolveConnectorSourceFieldType } from 'features/nodes/store/util/connectorTopology'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; @@ -22,23 +23,20 @@ export const buildSelectEdgeColor = ( target: string, targetHandleId: string | null | undefined ) => - createSelector(selectNodes, selectWorkflowSettingsSlice, (nodes, workflowSettings): string => { + createSelector(selectNodes, selectEdges, selectWorkflowSettingsSlice, (nodes, edges, workflowSettings): string => { const { shouldColorEdges } = workflowSettings; if (!shouldColorEdges) { return colorTokenToCssVar('base.500'); } const sourceNode = nodes.find((node) => node.id === source); - const targetNode = nodes.find((node) => node.id === target); - if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) { + if (!sourceNode || !sourceHandleId || !targetHandleId) { return colorTokenToCssVar('base.500'); } - const sourceNodeTemplate = templates[sourceNode.data.type]; - - const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId]; - const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; + const sourceType = isConnectorNode(sourceNode) + ? resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates) + : templates[sourceNode.data.type]?.outputs[sourceHandleId]?.type; return sourceType ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); }); @@ -50,7 +48,7 @@ export const buildSelectEdgeLabel = ( target: string, targetHandleId: string | null | undefined ) => - createSelector(selectNodes, (nodes): string | null => { + createSelector(selectNodes, selectEdges, (nodes, edges): string | null => { const sourceNode = nodes.find((node) => node.id === source); const targetNode = nodes.find((node) => node.id === target); @@ -58,8 +56,12 @@ export const buildSelectEdgeLabel = ( return null; } - const sourceNodeTemplate = templates[sourceNode.data.type]; + const resolvedSource = isConnectorNode(sourceNode) ? resolveConnectorSource(sourceNode.id, nodes, edges) : null; + const sourceTemplate = + resolvedSource !== null + ? templates[nodes.find((node) => node.id === resolvedSource.nodeId)?.data.type ?? ''] + : templates[sourceNode.data.type]; const targetNodeTemplate = templates[targetNode.data.type]; - return `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; + return `${sourceTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; }); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx new file mode 100644 index 0000000000..a971efb397 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx @@ -0,0 +1,97 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; +import { Box, Icon } from '@invoke-ai/ui-library'; +import type { Node, NodeProps } from '@xyflow/react'; +import { Handle, Position } from '@xyflow/react'; +import NonInvocationNodeWrapper from 'features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; +import { NO_DRAG_CLASS } from 'features/nodes/types/constants'; +import type { ConnectorNodeData } from 'features/nodes/types/invocation'; +import type { CSSProperties } from 'react'; +import { memo } from 'react'; +import { PiDotOutlineFill } from 'react-icons/pi'; + +const handleVisualSx = { + w: 3, + h: 3, + borderRadius: 'full', + borderWidth: 2, + borderColor: 'base.900', + bg: 'base.100', + pointerEvents: 'none', +} satisfies SystemStyleObject; + +const handleStyles = { + position: 'absolute', + width: '1rem', + height: '1rem', + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + top: '50%', + transform: 'translateY(-50%)', + zIndex: 1, + background: 'none', + border: 'none', +} satisfies CSSProperties; + +const inputHandleStyles = { + ...handleStyles, + insetInlineStart: 0, + justifyContent: 'flex-start', +} satisfies CSSProperties; + +const outputHandleStyles = { + ...handleStyles, + insetInlineEnd: 0, + justifyContent: 'flex-end', +} satisfies CSSProperties; + +const ConnectorNode = ({ id, selected }: NodeProps>) => { + return ( + + + + + + + + + + + + + + ); +}; + +export default memo(ConnectorNode); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx index 7e2cde7093..84246ba43b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx @@ -16,10 +16,12 @@ type NonInvocationNodeWrapperProps = PropsWithChildren & { nodeId: string; selected: boolean; width?: ChakraProps['w']; + borderRadius?: ChakraProps['borderRadius']; + withChrome?: boolean; }; const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => { - const { nodeId, width, children, selected } = props; + const { nodeId, width, children, selected, borderRadius = 'base', withChrome = true } = props; const mouseOverNode = useMouseOverNode(nodeId); const zoomToNode = useZoomToNode(nodeId); @@ -62,14 +64,15 @@ const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => { onMouseOut={mouseOverNode.handleMouseOut} className={DRAG_HANDLE_CLASSNAME} sx={containerSx} + borderRadius={borderRadius} width={width || NODE_WIDTH} opacity={opacity} data-is-selected={selected} > - - + {withChrome && } + {withChrome && } {children} - + {withChrome && } ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts index 70e56cb4db..15624780c2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts @@ -6,7 +6,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; export const containerSx: SystemStyleObject = { h: 'full', position: 'relative', - borderRadius: 'base', + borderRadius: 'inherit', transitionProperty: 'none', cursor: 'grab', '--border-color': 'var(--invoke-colors-base-500)', @@ -30,7 +30,7 @@ export const containerSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'base', + borderRadius: 'inherit', transitionProperty: 'none', pointerEvents: 'none', shadow: '0 0 0 1px var(--border-color)', @@ -64,7 +64,7 @@ export const shadowsSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'base', + borderRadius: 'inherit', pointerEvents: 'none', zIndex: -1, shadow: 'var(--invoke-shadows-xl), var(--invoke-shadows-base), var(--invoke-shadows-base)', @@ -76,7 +76,7 @@ export const inProgressSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'md', + borderRadius: 'inherit', pointerEvents: 'none', transitionProperty: 'none', opacity: 0.7, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index 591aac6c18..428fcf1e82 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -3,6 +3,7 @@ import { $templates } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; import { NODE_WIDTH } from 'features/nodes/types/constants'; import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode'; import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode'; @@ -14,7 +15,7 @@ export const useBuildNode = () => { return useCallback( // string here is "any invocation type" - (type: string | 'current_image' | 'notes'): AnyNode => { + (type: string | 'connector' | 'current_image' | 'notes'): AnyNode => { const flow = $flow.get(); assert(flow !== null); @@ -42,6 +43,10 @@ export const useBuildNode = () => { return buildNotesNode(position); } + if (type === 'connector') { + return buildConnectorNode(position); + } + // TODO: Keep track of invocation types so we do not need to cast this // We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates. const template = templates[type] as InvocationTemplate; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index bfd8be95f5..d763254bc4 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -12,9 +12,15 @@ import { edgesChanged, } from 'features/nodes/store/nodesSlice'; import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSourceFieldType, +} from 'features/nodes/store/util/connectorTopology'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import type { AnyEdge } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; import { useCallback, useMemo } from 'react'; import { assert } from 'tsafe'; @@ -33,6 +39,47 @@ export const useConnection = () => { return; } + if (isConnectorNode(node)) { + if (handleType === 'source' && handleId !== CONNECTOR_OUTPUT_HANDLE) { + return; + } + if (handleType === 'target' && handleId !== CONNECTOR_INPUT_HANDLE) { + return; + } + + const resolvedSourceType = + handleType === 'source' + ? resolveConnectorSourceFieldType(nodeId, nodes, selectNodesSlice(store.getState()).edges, templates) + : null; + $pendingConnection.set({ + nodeId, + handleId, + handleType, + fieldTemplate: + handleType === 'source' + ? { + name: CONNECTOR_OUTPUT_HANDLE, + title: 'Connector Output', + description: '', + fieldKind: 'output', + ui_hidden: false, + type: resolvedSourceType ?? { name: 'AnyField', cardinality: 'SINGLE', batch: false }, + } + : { + name: CONNECTOR_INPUT_HANDLE, + title: 'Connector Input', + description: '', + fieldKind: 'input', + input: 'connection', + required: false, + default: undefined, + ui_hidden: false, + type: { name: 'AnyField', cardinality: 'SINGLE', batch: false }, + }, + }); + return; + } + const template = templates[node.data.type]; if (!template) { return; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts new file mode 100644 index 0000000000..5306347921 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts @@ -0,0 +1,160 @@ +import { deepClone } from 'common/util/deepClone'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; +import { describe, expect, it } from 'vitest'; + +import { connectorInserted, nodesChanged, nodesSliceConfig } from './nodesSlice'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from './util/connectorTopology'; +import { add, buildEdge, buildNode, sub } from './util/testUtils'; + +const buildFixedConnectorNode = (id: string) => { + const connectorNode = buildConnectorNode({ x: 0, y: 0 }); + return { + ...connectorNode, + id, + data: { + ...connectorNode.data, + id, + }, + }; +}; + +describe('nodesSlice connector actions', () => { + it('splits a direct edge into source -> connector -> target edges when inserting a connector', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + const directEdge = buildEdge(source.id, 'value', target.id, 'a'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, target]; + initialState.edges = [directEdge]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + connectorInserted({ + edgeId: directEdge.id, + connector, + }) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id, connector.id]); + expect(nextState.edges).toEqual([ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]); + }); + + it('splices connector outputs back to the resolved upstream source when removed', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]); + expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]); + }); + + it('splices one connector source back to multiple downstream targets when removed', () => { + const source = buildNode(add); + const targetA = buildNode(sub); + const targetB = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, targetA, targetB]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'b'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, targetA.id, targetB.id]); + expect(nextState.edges).toEqual([ + buildEdge(source.id, 'value', targetA.id, 'a'), + buildEdge(source.id, 'value', targetB.id, 'b'), + ]); + }); + + it('does not create any edges when removing a connector with no downstream targets', () => { + const source = buildNode(add); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector]; + initialState.edges = [buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id]); + expect(nextState.edges).toEqual([]); + }); + + it('removes a connector while preserving downstream connector edges in a chained splice case', () => { + const source = buildNode(add); + const connectorA = buildFixedConnectorNode('connector-a'); + const connectorB = buildFixedConnectorNode('connector-b'); + const target = buildNode(sub); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connectorA, connectorB, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connectorA.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, connectorB.id, target.id]); + expect(nextState.edges).toHaveLength(2); + expect(nextState.edges).toEqual( + expect.arrayContaining([ + buildEdge(source.id, 'value', connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]) + ); + }); + + it('splices connector edges when the connector is removed through generic node removal', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]); + expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index bdab6c1ae3..6713ee8fb4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -20,6 +20,13 @@ import { reparentElement, } from 'features/nodes/components/sidePanel/builder/form-manipulation'; import { type NodesState, zNodesState } from 'features/nodes/store/types'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorOutputEdges, + resolveConnectorSource, +} from 'features/nodes/store/util/connectorTopology'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { BoardFieldValue, @@ -65,8 +72,8 @@ import { zStringGeneratorFieldValue, zStylePresetFieldValue, } from 'features/nodes/types/field'; -import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; -import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; +import type { AnyEdge, AnyNode, ConnectorNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import type { BuilderForm, ContainerElement, @@ -103,7 +110,7 @@ export const getInitialWorkflow = (): Omit[]>) => { + const removedConnectorSpliceEdges: AnyEdge[] = action.payload.flatMap((change) => { + if (change.type !== 'remove') { + return []; + } + + const node = state.nodes.find((candidate) => candidate.id === change.id); + if (!isConnectorNode(node)) { + return []; + } + + const resolvedSource = resolveConnectorSource(node.id, state.nodes, state.edges); + if (!resolvedSource) { + return []; + } + + return getConnectorOutputEdges(node.id, state.edges) + .filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default') + .map((edge) => + connectionToEdge({ + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + target: edge.target, + targetHandle: edge.targetHandle, + }) + ); + }); + // TODO(psyche): The below TS issue was recently fixed upstream. Need to upgrade @xyflow/react and then we // should be able to remove this cast. // @@ -206,6 +240,12 @@ const slice = createSlice({ if (edgeChanges.length > 0) { state.edges = applyEdgeChanges(edgeChanges, state.edges); } + if (removedConnectorSpliceEdges.length > 0) { + state.edges = applyEdgeChanges( + removedConnectorSpliceEdges.map((edge) => ({ type: 'add', item: edge })), + state.edges + ); + } } const wereNodesRemoved = action.payload.some((change) => change.type === 'remove' || change.type === 'replace'); @@ -396,11 +436,40 @@ const slice = createSlice({ } } }, + connectorInserted: ( + state, + action: PayloadAction<{ + edgeId: string; + connector: ConnectorNode; + }> + ) => { + const { edgeId, connector } = action.payload; + const edge = state.edges.find((candidate) => candidate.id === edgeId); + if (!edge || edge.type !== 'default') { + return; + } + state.nodes.push({ ...SHARED_NODE_PROPERTIES, ...connector } as (typeof state.nodes)[number]); + state.edges = state.edges.filter((candidate) => candidate.id !== edgeId); + state.edges.push( + connectionToEdge({ + source: edge.source, + sourceHandle: edge.sourceHandle ?? null, + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }), + connectionToEdge({ + source: connector.id, + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: edge.target, + targetHandle: edge.targetHandle ?? null, + }) + ); + }, nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => { const { nodeId, label } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); const node = state.nodes?.[nodeIndex]; - if (isInvocationNode(node) || isNotesNode(node)) { + if (isInvocationNode(node) || isNotesNode(node) || isConnectorNode(node)) { node.data.label = label; } }, @@ -614,6 +683,7 @@ export const { nodeEditorReset, nodeIsIntermediateChanged, nodeIsOpenChanged, + connectorInserted, nodeLabelChanged, nodeNotesChanged, nodesChanged, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts new file mode 100644 index 0000000000..e87ebcde79 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts @@ -0,0 +1,126 @@ +import type { AnyNode, ConnectorNode } from 'features/nodes/types/invocation'; +import { describe, expect, it } from 'vitest'; + +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorDeletionSpliceConnections, + getConnectorInputEdge, + getConnectorOutputEdges, + resolveConnectorSource, + resolveConnectorSourceFieldType, +} from './connectorTopology'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from './testUtils'; + +const buildConnectorNode = (id: string): ConnectorNode => ({ + id, + type: 'connector', + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector', + label: 'Connector', + isOpen: true, + }, +}); + +describe('connectorTopology', () => { + it('resolves the effective upstream source through one connector', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + expect(resolveConnectorSource(connector.id, nodes, edges)).toEqual({ + nodeId: source.id, + fieldName: 'value', + }); + expect(resolveConnectorSourceFieldType(connector.id, nodes, edges, templates)).toEqual(add.outputs.value?.type); + }); + + it('resolves the effective upstream source through chained connectors', () => { + const source = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const nodes: AnyNode[] = [source, connectorA, connectorB]; + const edges = [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + + expect(resolveConnectorSource(connectorB.id, nodes, edges)).toEqual({ + nodeId: source.id, + fieldName: 'value', + }); + }); + + it('returns no source or type for an unresolved connector chain', () => { + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const nodes: AnyNode[] = [connectorA, connectorB]; + const edges = [buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE)]; + + expect(resolveConnectorSource(connectorB.id, nodes, edges)).toBe(null); + expect(resolveConnectorSourceFieldType(connectorB.id, nodes, edges, templates)).toBe(null); + }); + + it('enumerates multiple outgoing edges for a connector', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const targetA = buildNode(sub); + const targetB = buildNode(img_resize); + const incoming = buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE); + const outgoingA = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'); + const outgoingB = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width'); + const edges = [incoming, outgoingA, outgoingB]; + + expect(getConnectorInputEdge(connector.id, edges)).toEqual(incoming); + expect(getConnectorOutputEdges(connector.id, edges)).toEqual([outgoingA, outgoingB]); + }); + + it('rejects connector deletion splice-through when any downstream target would be invalid', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(img_resize); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'image'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null); + }); + + it('builds connector deletion splice-through edges when every downstream target remains valid', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toEqual([ + { + source: source.id, + sourceHandle: 'value', + target: target.id, + targetHandle: 'a', + }, + ]); + }); + + it('returns no splice-through edges when a connector has downstream targets but no upstream source', () => { + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [connector, target]; + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts new file mode 100644 index 0000000000..e1267763d7 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts @@ -0,0 +1,228 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; + +export const CONNECTOR_INPUT_HANDLE = 'in'; +export const CONNECTOR_OUTPUT_HANDLE = 'out'; + +type ResolvedConnectorSource = { + nodeId: string; + fieldName: string; +}; + +type SpliceConnection = { + source: string; + sourceHandle: string; + target: string; + targetHandle: string; +}; + +type SpliceConnectionValidator = ( + connection: SpliceConnection, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates, + ignoreEdge: AnyEdge | null, + strict?: boolean +) => string | null; + +export const getConnectorInputEdge = (connectorId: string, edges: AnyEdge[]): AnyEdge | undefined => + edges.find( + (edge) => + edge.type === 'default' && + edge.target === connectorId && + edge.targetHandle === CONNECTOR_INPUT_HANDLE && + typeof edge.sourceHandle === 'string' + ); + +export const getConnectorOutputEdges = (connectorId: string, edges: AnyEdge[]): AnyEdge[] => + edges.filter( + (edge) => + edge.type === 'default' && + edge.source === connectorId && + edge.sourceHandle === CONNECTOR_OUTPUT_HANDLE && + typeof edge.targetHandle === 'string' + ); + +export const resolveConnectorSource = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[] +): ResolvedConnectorSource | null => { + const visited = new Set(); + + const resolve = (nodeId: string): ResolvedConnectorSource | null => { + if (visited.has(nodeId)) { + return null; + } + visited.add(nodeId); + + const incomingEdge = getConnectorInputEdge(nodeId, edges); + if (!incomingEdge || incomingEdge.type !== 'default') { + return null; + } + if (typeof incomingEdge.sourceHandle !== 'string') { + return null; + } + + const sourceNode = nodes.find((node) => node.id === incomingEdge.source); + if (!sourceNode) { + return null; + } + + if (isInvocationNode(sourceNode)) { + return { nodeId: sourceNode.id, fieldName: incomingEdge.sourceHandle }; + } + + if (isConnectorNode(sourceNode)) { + return resolve(sourceNode.id); + } + + return null; + }; + + return resolve(connectorId); +}; + +export const resolveConnectorSourceFieldType = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): FieldType | null => { + const resolvedSource = resolveConnectorSource(connectorId, nodes, edges); + if (!resolvedSource) { + return null; + } + + const sourceNode = nodes.find((node) => node.id === resolvedSource.nodeId); + if (!sourceNode || !isInvocationNode(sourceNode)) { + return null; + } + + const sourceTemplate = templates[sourceNode.data.type]; + return sourceTemplate?.outputs[resolvedSource.fieldName]?.type ?? null; +}; + +export const getConnectorDeletionSpliceConnections = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates, + validateConnection?: SpliceConnectionValidator +): SpliceConnection[] | null => { + const resolvedSource = resolveConnectorSource(connectorId, nodes, edges); + if (!resolvedSource) { + return null; + } + + const outputEdges = getConnectorOutputEdges(connectorId, edges); + const spliceConnections = outputEdges + .filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default') + .map((edge) => ({ + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + target: edge.target, + targetHandle: edge.targetHandle, + })); + + const deduped = new Set(); + for (const connection of spliceConnections) { + const key = `${connection.source}:${connection.sourceHandle}->${connection.target}:${connection.targetHandle}`; + if (deduped.has(key)) { + return null; + } + deduped.add(key); + } + + if (!validateConnection) { + const sourceType = resolveConnectorSourceFieldType(connectorId, nodes, edges, templates); + if (!sourceType) { + return null; + } + const inputEdgeId = getConnectorInputEdge(connectorId, edges)?.id; + const outputEdgeIds = new Set(outputEdges.map((edge) => edge.id)); + + for (const connection of spliceConnections) { + const targetNode = nodes.find((node) => node.id === connection.target); + if (!targetNode || !isInvocationNode(targetNode)) { + return null; + } + const targetTemplate = templates[targetNode.data.type]; + const targetFieldTemplate = targetTemplate?.inputs[connection.targetHandle]; + if (!targetFieldTemplate) { + return null; + } + + const matchesExistingDirectEdge = edges.some( + (edge) => + edge.type === 'default' && + edge.source === connection.source && + edge.sourceHandle === connection.sourceHandle && + edge.target === connection.target && + edge.targetHandle === connection.targetHandle + ); + if (matchesExistingDirectEdge) { + return null; + } + + const targetConflictCount = spliceConnections.filter( + (candidate) => candidate.target === connection.target && candidate.targetHandle === connection.targetHandle + ).length; + const existingTargetConflict = edges.some( + (edge) => + edge.type === 'default' && + edge.id !== inputEdgeId && + !outputEdgeIds.has(edge.id) && + edge.target === connection.target && + edge.targetHandle === connection.targetHandle + ); + if ( + targetFieldTemplate.type.name !== 'CollectionItemField' && + (targetConflictCount > 1 || existingTargetConflict) + ) { + return null; + } + + if ( + sourceType.name !== targetFieldTemplate.type.name && + targetFieldTemplate.type.name !== 'CollectionItemField' + ) { + return null; + } + } + + return spliceConnections; + } + + const ignoredEdgeIds = new Set([ + getConnectorInputEdge(connectorId, edges)?.id, + ...outputEdges.map((edge) => edge.id), + ]); + const existingEdges = edges.filter((edge) => !ignoredEdgeIds.has(edge.id)); + const stagedConnections: SpliceConnection[] = []; + + for (const connection of spliceConnections) { + const stagedEdges = [ + ...existingEdges, + ...stagedConnections.map( + ({ source, sourceHandle, target, targetHandle }) => + ({ + id: `splice-${source}-${sourceHandle}-${target}-${targetHandle}`, + type: 'default', + source, + sourceHandle, + target, + targetHandle, + }) satisfies AnyEdge + ), + ]; + if (validateConnection(connection, nodes, stagedEdges, templates, null, true) !== null) { + return null; + } + stagedConnections.push(connection); + } + + return spliceConnections; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts index 288a4f9066..b4374a920c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts @@ -1,13 +1,26 @@ import { deepClone } from 'common/util/deepClone'; import { unset } from 'es-toolkit/compat'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; import { getFirstValidConnection, getSourceCandidateFields, getTargetCandidateFields, } from 'features/nodes/store/util/getFirstValidConnection'; -import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils'; import { describe, expect, it } from 'vitest'; +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + describe('getFirstValidConnection', () => { it('should return null if the pending and candidate nodes are the same node', () => { const n = buildNode(add); @@ -120,6 +133,33 @@ describe('getFirstValidConnection', () => { expect(r).toEqual(null); }); }); + + it('should resolve connector target candidates when connecting an invocation output to a connector', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + expect(getFirstValidConnection(n1.id, 'value', connector.id, null, [n1, connector], [], templates, null)).toEqual({ + source: n1.id, + sourceHandle: 'value', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + }); + + it('should resolve connector source candidates when connecting a connector to a typed invocation input', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + + expect( + getFirstValidConnection(connector.id, null, n2.id, 'width', [n1, connector, n2], edges, templates, null) + ).toEqual({ + source: connector.id, + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: n2.id, + targetHandle: 'width', + }); + }); }); describe('getTargetCandidateFields', () => { @@ -160,6 +200,62 @@ describe('getTargetCandidateFields', () => { const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate); expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); }); + it('should return the connector input handle when the target is a connector', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const r = getTargetCandidateFields(n1.id, 'value', connector.id, [n1, connector], [], templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_INPUT_HANDLE]); + }); + it('should advertise typed target candidates for an unresolved connector output when no downstream constraint exists', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const r = getTargetCandidateFields( + connector.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [connector, n2], + [], + templates, + null + ); + expect(r.map((field) => field.name)).toEqual(['a', 'b']); + }); + it('should only advertise compatible typed target candidates for an unresolved connector output with downstream constraints', () => { + const connector = buildConnectorNode('connector-1'); + const n1 = buildNode(sub); + const n2 = buildNode(img_resize); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')]; + const r = getTargetCandidateFields( + connector.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [connector, n1, n2], + edges, + templates, + null + ); + expect(r.map((field) => field.name)).toEqual(['width', 'height']); + }); + it('should resolve chained connector sources like the direct upstream source', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(img_resize); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = getTargetCandidateFields( + connectorB.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [n1, connectorA, connectorB, n2], + edges, + templates, + null + ); + expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); + }); }); describe('getSourceCandidateFields', () => { @@ -200,4 +296,18 @@ describe('getSourceCandidateFields', () => { const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate); expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]); }); + it('should return the connector output handle when the source is a connector with a typed upstream source', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = getSourceCandidateFields(n2.id, 'width', connector.id, [n1, connector, n2], edges, templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]); + }); + it('should return a target-constrained connector source candidate when the connector chain is unresolved', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const r = getSourceCandidateFields(n2.id, 'width', connector.id, [connector, n2], [], templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 0b5aa17e17..2a466aae44 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -1,9 +1,15 @@ import type { Connection } from '@xyflow/react'; import { map } from 'es-toolkit/compat'; import type { Templates } from 'features/nodes/store/types'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSourceFieldType, +} from 'features/nodes/store/util/connectorTopology'; import { validateConnection } from 'features/nodes/store/util/validateConnection'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; /** * @@ -91,16 +97,43 @@ export const getTargetCandidateFields = ( return []; } - const sourceTemplate = templates[sourceNode.data.type]; + if (isConnectorNode(targetNode)) { + const candidate = { + name: CONNECTOR_INPUT_HANDLE, + title: 'Connector Input', + description: '', + fieldKind: 'input', + input: 'connection', + required: false, + default: undefined, + ui_hidden: false, + type: { + name: 'AnyField', + cardinality: 'SINGLE', + batch: false, + }, + } satisfies FieldInputTemplate; + + const c = { source, sourceHandle, target, targetHandle: candidate.name }; + return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : []; + } + const targetTemplate = templates[targetNode.data.type]; - if (!sourceTemplate || !targetTemplate) { + if (!targetTemplate) { return []; } - const sourceField = sourceTemplate.outputs[sourceHandle]; + if (!isConnectorNode(sourceNode)) { + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return []; + } - if (!sourceField) { - return []; + const sourceField = sourceTemplate.outputs[sourceHandle]; + + if (!sourceField) { + return []; + } } const targetCandidateFields = map(targetTemplate.inputs).filter((field) => { @@ -127,15 +160,45 @@ export const getSourceCandidateFields = ( return []; } + if (isConnectorNode(sourceNode)) { + const sourceFieldType = resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates); + const targetTemplate = !isConnectorNode(targetNode) ? templates[targetNode.data.type] : null; + const targetFieldType = targetTemplate?.inputs[targetHandle]?.type; + const candidateType = sourceFieldType ?? targetFieldType; + if (!candidateType) { + return []; + } + + const candidate = { + name: CONNECTOR_OUTPUT_HANDLE, + title: 'Connector Output', + description: '', + fieldKind: 'output', + ui_hidden: false, + type: candidateType, + } satisfies FieldOutputTemplate; + + const c = { source, sourceHandle: candidate.name, target, targetHandle }; + return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : []; + } + const sourceTemplate = templates[sourceNode.data.type]; - const targetTemplate = templates[targetNode.data.type]; - if (!sourceTemplate || !targetTemplate) { + if (!sourceTemplate) { return []; } - const targetField = targetTemplate.inputs[targetHandle]; + if (!isConnectorNode(targetNode)) { + const targetTemplate = templates[targetNode.data.type]; + if (!targetTemplate) { + return []; + } - if (!targetField) { + const targetField = targetTemplate.inputs[targetHandle]; + + if (!targetField) { + return []; + } + } else if (targetHandle !== CONNECTOR_INPUT_HANDLE) { return []; } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts new file mode 100644 index 0000000000..b70eda4bda --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts @@ -0,0 +1,23 @@ +import { describe, expect, it } from 'vitest'; + +import { connectionToEdge } from './reactFlowUtil'; + +describe('connectionToEdge', () => { + it('creates a default edge with the expected id and endpoints', () => { + expect( + connectionToEdge({ + source: 'source-node', + sourceHandle: 'value', + target: 'target-node', + targetHandle: 'a', + }) + ).toEqual({ + type: 'default', + source: 'source-node', + sourceHandle: 'value', + target: 'target-node', + targetHandle: 'a', + id: 'reactflow__edge-source-nodevalue-target-nodea', + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts index 67264bf41b..3eaece154f 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts @@ -24,6 +24,7 @@ export const connectionToEdge = (connection: Connection): AnyEdge => { const { source, sourceHandle, target, targetHandle } = connection; assert(source && sourceHandle && target && targetHandle, 'Invalid connection'); return { + type: 'default', source, sourceHandle, target, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 88eae8484f..29395a1e1d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -3,6 +3,11 @@ import { set } from 'es-toolkit/compat'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { describe, expect, it } from 'vitest'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorDeletionSpliceConnections, +} from './connectorTopology'; import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils'; import { validateConnection } from './validateConnection'; @@ -139,6 +144,18 @@ const integerCollectionOutputTemplate: InvocationTemplate = { classification: 'stable', }; +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + describe(validateConnection.name, () => { it('should reject invalid connection to self', () => { const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; @@ -458,6 +475,214 @@ describe(validateConnection.name, () => { expect(r).toEqual('nodes.connectionWouldCreateCycle'); }); + describe('connectors', () => { + it('should accept invocation output to connector input', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const r = validateConnection( + { source: n1.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [n1, connector], + [], + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject a second input into a connector', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [n1, n2, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection'); + }); + + it('should accept connector output to invocation input when the upstream type matches', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [n1, connector, n2], + edges, + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject connector output to invocation input when the upstream type mismatches', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' }, + [n1, connector, n2], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should accept unresolved connector output to a typed invocation input as the first downstream constraint', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [connector, n2], + [], + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject unresolved connector output when it conflicts with an existing downstream typed constraint', () => { + const connector = buildConnectorNode('connector-1'); + const n1 = buildNode(sub); + const n2 = buildNode(img_resize); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' }, + [connector, n1, n2], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should reject connecting an incompatible upstream source into a connector with downstream typed constraints', () => { + const source = buildNode(main_model_loader); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]; + const r = validateConnection( + { source: source.id, sourceHandle: 'vae', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [source, connector, target], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should preserve type information through chained connectors', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(sub); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = validateConnection( + { source: connectorB.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [n1, connectorA, connectorB, n2], + edges, + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject cycles routed through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + ]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }, + [n1, n2, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.connectionWouldCreateCycle'); + }); + + it('should preserve collect item validation through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(collect); + const n3 = buildNode(main_model_loader); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', n2.id, 'item'), + buildEdge(n3.id, 'vae', connector.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'item' }, + [n1, n2, n3, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes'); + }); + + it('should preserve if branch validation through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(img_resize); + const n3 = buildNode(ifTemplate); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n3.id, 'true_input'), + ]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'false_input' }, + [n1, n2, n3, connector], + edges, + { ...templates, if: ifTemplate }, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should reject connector deletion splice-through when it would duplicate an existing direct edge', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + buildEdge(n1.id, 'value', n2.id, 'a'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, [n1, n2, connector], edges, templates)).toBe(null); + }); + + it('should reject connector deletion splice-through when fan-out would violate a single-input target', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(sub); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + ]; + + expect( + getConnectorDeletionSpliceConnections(connectorA.id, [n1, connectorA, connectorB, n2], edges, templates) + ).toBe(null); + }); + }); + describe('non-strict mode', () => { it('should reject connections from self to self in non-strict mode', () => { const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index b342df064b..bb98d472d3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -1,10 +1,17 @@ import type { Connection as NullableConnection } from '@xyflow/react'; import type { Templates } from 'features/nodes/store/types'; import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSource, +} from 'features/nodes/store/util/connectorTopology'; import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyEdge, AnyNode, InvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; import type { SetNonNullable } from 'type-fest'; type Connection = SetNonNullable; @@ -18,6 +25,12 @@ type ValidateConnectionFunc = ( strict?: boolean ) => string | null; +type EffectiveSource = { + node: InvocationNode; + handle: string; + fieldTemplate: NonNullable['outputs'][string]; +}; + const getEqualityPredicate = (c: Connection) => (e: AnyEdge): boolean => { @@ -41,10 +54,7 @@ const isIfInputHandle = (handle: string): handle is (typeof IF_INPUT_HANDLES)[nu return IF_INPUT_HANDLES.includes(handle as (typeof IF_INPUT_HANDLES)[number]); }; -const isSingleCollectionPairOfSameBaseType = ( - firstType: { name: string; cardinality: string; batch: boolean }, - secondType: { name: string; cardinality: string; batch: boolean } -) => { +const isSingleCollectionPairOfSameBaseType = (firstType: FieldType, secondType: FieldType) => { const isSingleToCollection = firstType.cardinality === 'SINGLE' && secondType.cardinality === 'COLLECTION' && firstType.name === secondType.name; const isCollectionToSingle = @@ -52,6 +62,139 @@ const isSingleCollectionPairOfSameBaseType = ( return firstType.batch === secondType.batch && (isSingleToCollection || isCollectionToSingle); }; +const areFieldTypesCompatible = (firstType: FieldType, secondType: FieldType) => + validateConnectionTypes(firstType, secondType) || + validateConnectionTypes(secondType, firstType) || + isSingleCollectionPairOfSameBaseType(firstType, secondType); + +type ConnectorTerminalTargetEdge = AnyEdge & { + type: 'default'; + sourceHandle: string; + targetHandle: string; +}; + +const getConnectorTerminalTargetEdges = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => { + const visited = new Set(); + const resolve = (currentConnectorId: string): ConnectorTerminalTargetEdge[] => { + if (visited.has(currentConnectorId)) { + return []; + } + visited.add(currentConnectorId); + + return edges.flatMap((edge) => { + if ( + edge.type !== 'default' || + edge.source !== currentConnectorId || + edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE || + typeof edge.targetHandle !== 'string' + ) { + return []; + } + + const targetNode = nodes.find((node) => node.id === edge.target); + if (targetNode && isConnectorNode(targetNode)) { + return resolve(targetNode.id); + } + + return [edge as ConnectorTerminalTargetEdge]; + }); + }; + + return resolve(connectorId); +}; + +const getConnectorSubgraphEdgeIds = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => { + const visited = new Set(); + const edgeIds = new Set(); + + const visit = (currentConnectorId: string) => { + if (visited.has(currentConnectorId)) { + return; + } + visited.add(currentConnectorId); + + edges.forEach((edge) => { + if ( + edge.type !== 'default' || + edge.source !== currentConnectorId || + edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE || + typeof edge.targetHandle !== 'string' + ) { + return; + } + + edgeIds.add(edge.id); + + const targetNode = nodes.find((node) => node.id === edge.target); + if (targetNode && isConnectorNode(targetNode)) { + visit(targetNode.id); + } + }); + }; + + visit(connectorId); + return edgeIds; +}; + +const getEffectiveSource = ( + sourceId: string, + sourceHandle: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => { + const sourceNode = nodes.find((n) => n.id === sourceId); + if (!sourceNode) { + return 'nodes.missingNode'; + } + + if (isConnectorNode(sourceNode)) { + if (sourceHandle !== CONNECTOR_OUTPUT_HANDLE) { + return 'nodes.missingFieldTemplate'; + } + + const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges); + if (!resolvedSource) { + return null; + } + + return getEffectiveSource(resolvedSource.nodeId, resolvedSource.fieldName, nodes, edges, templates); + } + + if (!isInvocationNode(sourceNode)) { + return 'nodes.missingInvocationTemplate'; + } + + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return 'nodes.missingInvocationTemplate'; + } + + const sourceFieldTemplate = sourceTemplate.outputs[sourceHandle]; + if (!sourceFieldTemplate) { + return 'nodes.missingFieldTemplate'; + } + + return { + node: sourceNode, + handle: sourceHandle, + fieldTemplate: sourceFieldTemplate, + }; +}; + +const getEffectiveSourceForEdge = ( + edge: AnyEdge, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => { + if (edge.type !== 'default' || typeof edge.sourceHandle !== 'string') { + return null; + } + + return getEffectiveSource(edge.source, edge.sourceHandle, nodes, edges, templates); +}; + /** * Validates a connection between two fields * @returns A translation key for an error if the connection is invalid, otherwise null @@ -83,18 +226,63 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.cannotDuplicateConnection'; } - const sourceNode = nodes.find((n) => n.id === c.source); - if (!sourceNode) { - return 'nodes.missingNode'; - } - const targetNode = nodes.find((n) => n.id === c.target); + const sourceNode = nodes.find((n) => n.id === c.source); if (!targetNode) { return 'nodes.missingNode'; } - const sourceTemplate = templates[sourceNode.data.type]; - if (!sourceTemplate) { + const effectiveSource = getEffectiveSource(c.source, c.sourceHandle, nodes, filteredEdges, templates); + if (effectiveSource === 'nodes.missingNode') { + return 'nodes.missingNode'; + } + if (effectiveSource === 'nodes.missingInvocationTemplate') { + return 'nodes.missingInvocationTemplate'; + } + if (effectiveSource === 'nodes.missingFieldTemplate') { + return 'nodes.missingFieldTemplate'; + } + + if (isConnectorNode(targetNode)) { + if (c.targetHandle !== CONNECTOR_INPUT_HANDLE) { + return 'nodes.missingFieldTemplate'; + } + + if (filteredEdges.find(getTargetEqualityPredicate(c))) { + return 'nodes.inputMayOnlyHaveOneConnection'; + } + + if (effectiveSource) { + const connectorSubgraphEdgeIds = getConnectorSubgraphEdgeIds(targetNode.id, nodes, filteredEdges); + const stagedEdges = filteredEdges.filter((edge) => !connectorSubgraphEdgeIds.has(edge.id)); + const terminalTargetEdges = getConnectorTerminalTargetEdges(targetNode.id, nodes, filteredEdges); + + for (const terminalTargetEdge of terminalTargetEdges) { + const downstreamValidation = validateConnection( + { + source: c.source, + sourceHandle: c.sourceHandle, + target: terminalTargetEdge.target, + targetHandle: terminalTargetEdge.targetHandle, + }, + nodes, + stagedEdges, + templates, + null, + true + ); + + if (downstreamValidation !== null) { + return downstreamValidation; + } + } + } + + // Unresolved connector chains are allowed to terminate on another connector. + return null; + } + + if (!isInvocationNode(targetNode)) { return 'nodes.missingInvocationTemplate'; } @@ -103,11 +291,6 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.missingInvocationTemplate'; } - const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; - if (!sourceFieldTemplate) { - return 'nodes.missingFieldTemplate'; - } - const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; if (!targetFieldTemplate) { return 'nodes.missingFieldTemplate'; @@ -117,23 +300,58 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.cannotConnectToDirectInput'; } + if (!effectiveSource) { + if (sourceNode && isConnectorNode(sourceNode) && c.sourceHandle === CONNECTOR_OUTPUT_HANDLE) { + const existingTerminalTargetEdges = getConnectorTerminalTargetEdges(sourceNode.id, nodes, filteredEdges).filter( + (edge) => !(edge.target === c.target && edge.targetHandle === c.targetHandle) + ); + + for (const terminalTargetEdge of existingTerminalTargetEdges) { + const constrainedTargetNode = nodes.find((node) => node.id === terminalTargetEdge.target); + if (!constrainedTargetNode || !isInvocationNode(constrainedTargetNode)) { + return 'nodes.missingInvocationTemplate'; + } + + const constrainedTargetTemplate = templates[constrainedTargetNode.data.type]; + if (!constrainedTargetTemplate) { + return 'nodes.missingInvocationTemplate'; + } + + const constrainedTargetFieldTemplate = constrainedTargetTemplate.inputs[terminalTargetEdge.targetHandle]; + if (!constrainedTargetFieldTemplate) { + return 'nodes.missingFieldTemplate'; + } + + if (!areFieldTypesCompatible(constrainedTargetFieldTemplate.type, targetFieldTemplate.type)) { + return 'nodes.fieldTypesMustMatch'; + } + } + + return null; + } + + return 'nodes.fieldTypesMustMatch'; + } + + const { node: resolvedSourceNode, handle: sourceHandle, fieldTemplate: sourceFieldTemplate } = effectiveSource; + if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { // Collect nodes shouldn't mix and match field types. - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + const collectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id); if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { return 'nodes.cannotMixAndMatchCollectionItemTypes'; } } if ( - sourceNode.data.type === 'collect' && - c.sourceHandle === 'collection' && + resolvedSourceNode.data.type === 'collect' && + sourceHandle === 'collection' && targetNode.data.type === 'collect' && c.targetHandle === 'collection' ) { // Chained collect nodes should preserve a single item type when both ends are already typed. - const sourceCollectItemType = getCollectItemType(templates, nodes, edges, sourceNode.id); - const targetCollectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + const sourceCollectItemType = getCollectItemType(templates, nodes, filteredEdges, resolvedSourceNode.id); + const targetCollectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id); if ( sourceCollectItemType && targetCollectItemType && @@ -148,33 +366,24 @@ export const validateConnection: ValidateConnectionFunc = ( const siblingInputEdge = filteredEdges.find((e) => e.target === c.target && e.targetHandle === siblingHandle); if (siblingInputEdge) { - if (siblingInputEdge.source === null || siblingInputEdge.source === undefined) { + const siblingEffectiveSource = getEffectiveSourceForEdge(siblingInputEdge, nodes, filteredEdges, templates); + if (siblingEffectiveSource === 'nodes.missingNode') { return 'nodes.missingNode'; } - - if (siblingInputEdge.sourceHandle === null || siblingInputEdge.sourceHandle === undefined) { - return 'nodes.missingFieldTemplate'; - } - - const siblingSourceNode = nodes.find((n) => n.id === siblingInputEdge.source); - if (!siblingSourceNode) { - return 'nodes.missingNode'; - } - - const siblingSourceTemplate = templates[siblingSourceNode.data.type]; - if (!siblingSourceTemplate) { + if (siblingEffectiveSource === 'nodes.missingInvocationTemplate') { return 'nodes.missingInvocationTemplate'; } - - const siblingSourceFieldTemplate = siblingSourceTemplate.outputs[siblingInputEdge.sourceHandle]; - if (!siblingSourceFieldTemplate) { + if (siblingEffectiveSource === 'nodes.missingFieldTemplate') { return 'nodes.missingFieldTemplate'; } + if (!siblingEffectiveSource) { + return 'nodes.fieldTypesMustMatch'; + } const areIfInputTypesCompatible = - validateConnectionTypes(sourceFieldTemplate.type, siblingSourceFieldTemplate.type) || - validateConnectionTypes(siblingSourceFieldTemplate.type, sourceFieldTemplate.type) || - isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingSourceFieldTemplate.type); + validateConnectionTypes(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type) || + validateConnectionTypes(siblingEffectiveSource.fieldTemplate.type, sourceFieldTemplate.type) || + isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type); if (!areIfInputTypesCompatible) { return 'nodes.fieldTypesMustMatch'; @@ -189,30 +398,17 @@ export const validateConnection: ValidateConnectionFunc = ( } } - if (sourceNode.data.type === 'if' && c.sourceHandle === 'value') { + if (resolvedSourceNode.data.type === 'if' && sourceHandle === 'value') { const ifInputEdges = filteredEdges.filter( - (e) => e.target === sourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle) + (e) => + e.target === resolvedSourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle) ); const ifInputTypes = ifInputEdges.flatMap((edge) => { - if (edge.source === null || edge.source === undefined) { + const ifInputSource = getEffectiveSourceForEdge(edge, nodes, filteredEdges, templates); + if (!ifInputSource || typeof ifInputSource === 'string') { return []; } - if (edge.sourceHandle === null || edge.sourceHandle === undefined) { - return []; - } - const ifInputSourceNode = nodes.find((n) => n.id === edge.source); - if (!ifInputSourceNode) { - return []; - } - const ifInputSourceTemplate = templates[ifInputSourceNode.data.type]; - if (!ifInputSourceTemplate) { - return []; - } - const ifInputSourceFieldTemplate = ifInputSourceTemplate.outputs[edge.sourceHandle]; - if (!ifInputSourceFieldTemplate) { - return []; - } - return [ifInputSourceFieldTemplate.type]; + return [ifInputSource.fieldTemplate.type]; }); if (ifInputTypes.length > 0) { diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 8cd529deb7..9e82144e4e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -43,6 +43,12 @@ export const zNotesNodeData = z.object({ isOpen: z.boolean(), notes: z.string(), }); +export const zConnectorNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('connector'), + label: z.string(), + isOpen: z.boolean(), +}); const zCurrentImageNodeData = z.object({ id: z.string().trim().min(1), type: z.literal('current_image'), @@ -52,6 +58,7 @@ const zCurrentImageNodeData = z.object({ export type NotesNodeData = z.infer; export type InvocationNodeData = z.infer; +export type ConnectorNodeData = z.infer; type CurrentImageNodeData = z.infer; const zInvocationNodeValidationSchema = z.looseObject({ @@ -70,6 +77,15 @@ const zNotesNodeValidationSchema = z.looseObject({ const zNotesNode = z.custom>((val) => zNotesNodeValidationSchema.safeParse(val).success); export type NotesNode = z.infer; +const zConnectorNodeValidationSchema = z.looseObject({ + type: z.literal('connector'), + data: zConnectorNodeData, +}); +const zConnectorNode = z.custom>( + (val) => zConnectorNodeValidationSchema.safeParse(val).success +); +export type ConnectorNode = z.infer; + const zCurrentImageNodeValidationSchema = z.looseObject({ type: z.literal('current_image'), data: zCurrentImageNodeData, @@ -79,12 +95,14 @@ const zCurrentImageNode = z.custom>( ); export type CurrentImageNode = z.infer; -export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]); +export const zAnyNode = z.union([zInvocationNode, zNotesNode, zConnectorNode, zCurrentImageNode]); export type AnyNode = z.infer; export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => Boolean(node && node.type === 'invocation'); export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isConnectorNode = (node?: AnyNode | null): node is ConnectorNode => + Boolean(node && node.type === 'connector'); // #endregion // #region NodeExecutionState diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index 66e69ec585..34f98eb289 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -3,7 +3,7 @@ import { z } from 'zod'; import type { FieldType } from './field'; import { zFieldIdentifier } from './field'; -import { zInvocationNodeData, zNotesNodeData } from './invocation'; +import { zConnectorNodeData, zInvocationNodeData, zNotesNodeData } from './invocation'; // #region Workflow misc const zXYPosition = z @@ -31,7 +31,13 @@ const zWorkflowNotesNode = z.object({ data: zNotesNodeData, position: zXYPosition, }); -const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); +const zWorkflowConnectorNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('connector'), + data: zConnectorNodeData, + position: zXYPosition, +}); +const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode, zWorkflowConnectorNode]); type WorkflowInvocationNode = z.infer; @@ -377,7 +383,7 @@ export const zWorkflowV3 = z.object({ exposedFields: z.array(zFieldIdentifier), meta: z.object({ category: zWorkflowCategory.default('user'), - version: z.literal('3.0.0'), + version: z.literal('4.0.0'), }), // Use the validated form schema! form: zValidatedBuilderForm, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts new file mode 100644 index 0000000000..b44c6b38cd --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts @@ -0,0 +1,194 @@ +import { deepClone } from 'common/util/deepClone'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils'; +import { describe, expect, it } from 'vitest'; + +import { buildNodesGraph } from './buildNodesGraph'; + +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + +const buildState = (nodes: unknown[], edges: unknown[]) => + ({ + nodes: { + present: { + _version: 1, + nodes, + edges, + formFieldInitialValues: {}, + id: undefined, + name: '', + author: '', + description: '', + version: '', + contact: '', + tags: '', + notes: '', + exposedFields: [], + meta: { version: '4.0.0', category: 'user' }, + form: { + rootElementId: 'root', + elements: { + root: { + id: 'root', + type: 'container', + data: { layout: 'column', children: [] }, + }, + }, + }, + }, + }, + gallery: { + autoAddBoardId: 'none', + selection: [], + }, + }) as unknown as Parameters[0]; + +describe('buildNodesGraph', () => { + it('flattens a single connector to one direct execution edge', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.nodes).not.toHaveProperty(connector.id); + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('flattens chained connectors transitively', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const state = buildState( + [source, target, connectorA, connectorB], + [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('fans out through a connector into multiple execution edges', () => { + const source = buildNode(add); + const targetA = buildNode(sub); + const targetB = buildNode(img_resize); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, targetA, targetB, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: targetA.id, field: 'a' }, + }, + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: targetB.id, field: 'width' }, + }, + ]); + }); + + it('drops unresolved connector paths from the execution graph', () => { + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState([target, connector], [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]); + + const graph = buildNodesGraph(state, templates); + + expect(graph.nodes).not.toHaveProperty(connector.id); + expect(graph.edges).toEqual([]); + }); + + it('deduplicates effective execution edges created by flattening', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + buildEdge(source.id, 'value', target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('still omits explicit destination input values when the flattened edge exists', () => { + const source = buildNode(add); + const target = deepClone(buildNode(sub)); + const connector = buildConnectorNode('connector-1'); + const inputA = target.data.inputs.a; + expect(inputA).toBeDefined(); + if (!inputA) { + throw new Error('Missing input a'); + } + inputA.value = 'not-an-integer' as never; + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + expect(graph.nodes[target.id]).not.toHaveProperty('a'); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts index d83555e558..50052c806c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts @@ -4,10 +4,11 @@ import { omit, reduce } from 'es-toolkit/compat'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; import { selectNodesSlice } from 'features/nodes/store/selectors'; import type { Templates } from 'features/nodes/store/types'; +import { resolveConnectorSource } from 'features/nodes/store/util/connectorTopology'; import type { BoardField } from 'features/nodes/types/common'; import type { BoardFieldInputInstance } from 'features/nodes/types/field'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate } from 'features/nodes/types/field'; -import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation'; import type { AnyInvocation, Graph } from 'services/api/types'; import { v4 as uuidv4 } from 'uuid'; @@ -96,12 +97,57 @@ export const buildNodesGraph = (state: RootState, templates: Templates): Require const filteredNodeIds = filteredNodes.map(({ id }) => id); // skip out the "dummy" edges between collapsed nodes - const filteredEdges = edges - .filter((edge) => edge.type !== 'collapsed') - .filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target)); + const flattenedEdges = edges + .filter((edge) => edge.type === 'default') + .flatMap((edge) => { + const targetNode = nodes.find((node) => node.id === edge.target); + if (!targetNode || !isInvocationNode(targetNode) || !isExecutableNode(targetNode)) { + return []; + } + + const sourceNode = nodes.find((node) => node.id === edge.source); + if (!sourceNode) { + return []; + } + + if (isInvocationNode(sourceNode)) { + if (!isExecutableNode(sourceNode) || !filteredNodeIds.includes(sourceNode.id)) { + return []; + } + return [edge]; + } + + if (isConnectorNode(sourceNode)) { + const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges); + if (!resolvedSource || !filteredNodeIds.includes(resolvedSource.nodeId)) { + return []; + } + return [ + { + ...edge, + id: `flattened-${resolvedSource.nodeId}-${resolvedSource.fieldName}-${edge.target}-${edge.targetHandle}`, + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + }, + ]; + } + + return []; + }) + .filter((edge, index, allEdges) => { + return ( + allEdges.findIndex( + (candidate) => + candidate.source === edge.source && + candidate.sourceHandle === edge.sourceHandle && + candidate.target === edge.target && + candidate.targetHandle === edge.targetHandle + ) === index + ); + }); // Reduce the node editor edges into invocation graph edges - const parsedEdges = filteredEdges.reduce>((edgesAccumulator, edge) => { + const parsedEdges = flattenedEdges.reduce>((edgesAccumulator, edge) => { const { source, target, sourceHandle, targetHandle } = edge; if (!sourceHandle || !targetHandle) { diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts new file mode 100644 index 0000000000..18cb45fb1c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts @@ -0,0 +1,21 @@ +import type { XYPosition } from '@xyflow/react'; +import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; +import type { ConnectorNode } from 'features/nodes/types/invocation'; +import { v4 as uuidv4 } from 'uuid'; + +export const buildConnectorNode = (position: XYPosition): ConnectorNode => { + const nodeId = uuidv4(); + const node: ConnectorNode = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'connector', + position, + data: { + id: nodeId, + type: 'connector', + isOpen: true, + label: 'Connector', + }, + }; + return node; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 17cd0a33a7..21f0c38009 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -5,7 +5,7 @@ import { parseify } from 'common/util/serialize'; import { pick } from 'es-toolkit/compat'; import { selectNodesSlice } from 'features/nodes/store/selectors'; import type { NodesState } from 'features/nodes/store/types'; -import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { zWorkflowV3 } from 'features/nodes/types/workflow'; import i18n from 'i18n'; @@ -42,6 +42,9 @@ export const buildWorkflowFast = (nodesState: NodesState): WorkflowV3 => { if (isInvocationNode(node) && node.type) { const { id, type, data, position } = node; newWorkflow.nodes.push({ id, type, data, position }); + } else if (isConnectorNode(node) && node.type) { + const { id, type, data, position } = node; + newWorkflow.nodes.push({ id, type, data, position }); } else if (isNotesNode(node) && node.type) { const { id, type, data, position } = node; newWorkflow.nodes.push({ id, type, data, position }); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts index 05efc26fea..c09f4e1729 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts @@ -30,7 +30,7 @@ export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): Wor description: '', meta: { category: 'user', - version: '3.0.0', + version: '4.0.0', }, notes: '', tags: '', diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index 32971a02d0..638e66806d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -70,9 +70,11 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { }; const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => { - // Bump version - (workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0'; - // Parsing strips out any extra properties not in the latest version + return migrateV3toV4(workflowToMigrate as unknown as WorkflowV3); +}; + +const migrateV3toV4 = (workflowToMigrate: WorkflowV3): WorkflowV3 => { + workflowToMigrate.meta.version = '4.0.0'; return zWorkflowV3.parse(workflowToMigrate); }; @@ -100,6 +102,10 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => { workflow = migrateV2toV3(v2); } + if (get(workflow, 'meta.version') === '3.0.0') { + workflow = migrateV3toV4(workflow as WorkflowV3); + } + // We should now have a V3 workflow const migratedWorkflow = zWorkflowV3.parse(workflow); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts index 1442a3475e..c1d0858831 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts @@ -1,4 +1,5 @@ import { get } from 'es-toolkit/compat'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { getDefaultForm } from 'features/nodes/types/workflow'; @@ -7,6 +8,17 @@ import { describe, expect, it } from 'vitest'; //TODO(psyche): Test workflow validation for form builder fields describe('validateWorkflow', () => { + const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, + }); const getWorkflow = (): WorkflowV3 => ({ name: '', author: '', @@ -17,7 +29,7 @@ describe('validateWorkflow', () => { notes: '', exposedFields: [], form: getDefaultForm(), - meta: { version: '3.0.0', category: 'user' }, + meta: { version: '4.0.0', category: 'user' }, nodes: [ { id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad', @@ -104,6 +116,7 @@ describe('validateWorkflow', () => { }); expect(validationResult.warnings.length).toBe(1); expect(get(validationResult, 'workflow.nodes[1].data.inputs.image.value')).toBeUndefined(); + expect(validationResult.workflow.meta.version).toBe('4.0.0'); }); it('should reset boards that are inaccessible', async () => { const validationResult = await validateWorkflow({ @@ -127,4 +140,138 @@ describe('validateWorkflow', () => { expect(validationResult.warnings.length).toBe(1); expect(get(validationResult, 'workflow.nodes[0].data.inputs.model.value')).toBeUndefined(); }); + + it('should delete malformed connector edges with invalid handles', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: 'connector-1', + targetHandle: 'wrong', + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should delete connector edges with missing endpoints', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: 'missing-node', + sourceHandle: 'value', + target: 'connector-1', + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should repair invalid multi-input connector state predictably by keeping the first valid input edge', async () => { + const workflow = getWorkflow(); + const loader2 = structuredClone(workflow.nodes[0]!); + loader2.id = 'second-loader'; + loader2.data.id = 'second-loader'; + const connector = buildConnectorNode('connector-1'); + workflow.nodes.push(loader2, connector); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + workflow.edges.push({ + id: 'e2', + type: 'default', + source: loader2.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([ + { + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }, + ]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should retain isolated connectors during workflow validation', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.nodes.find((node) => node.id === 'connector-1')).toBeDefined(); + expect(validationResult.warnings).toEqual([]); + }); + + it('should retain unresolved connector output edges that establish downstream constraints in the editor', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + const unresolvedEdge = { + id: 'e1', + type: 'default' as const, + source: 'connector-1', + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: workflow.nodes[1]!.id, + targetHandle: 'image', + }; + workflow.edges.push(unresolvedEdge); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([unresolvedEdge]); + expect(validationResult.warnings).toEqual([]); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index b86870d450..448214defe 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,7 @@ import { parseify } from 'common/util/serialize'; import { addElement, getIsFormEmpty } from 'features/nodes/components/sidePanel/builder/form-manipulation'; import type { Templates } from 'features/nodes/store/types'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { isBoardFieldInputInstance, isImageFieldCollectionInputInstance, @@ -149,8 +150,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise(); + const validEdges = []; for (const edge of edges) { // Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow. @@ -169,7 +169,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise !edgesToDelete.has(id)); + _workflow.edges = validEdges; // Migrated exposed fields to form elements if they exist and the form does not // Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!