From dc1e8048878d37d0128c78718ddc1ec6aeb50cab Mon Sep 17 00:00:00 2001 From: CrypticWit <34725614+CrypticWit@users.noreply.github.com> Date: Fri, 29 Sep 2023 21:12:57 +1300 Subject: [PATCH] Workflow editor improvements - add node from empty connection and auto-connect to empy handle. (#4684) * Initial commit of edge drag feature. * Fixed build warnings * code cleanup and drag to existing node * improved isValidConnection check * fixed build issues, removed cyclic dependency * edge created nodes now spawn at cursor * Add Node popover will no longer show when using drag to delete an edge. * Fixed collection handling, added priority for handles matching name of source handle, removed current image/notes nodes from filtered list * Fixed not properly clearing startParams when closing the Add Node popover * fix(ui): do not allow Collect -> Iterate connection This can be removed when #3956 is resolved * feat(ui): use existing node validation logic in add-node-on-drop This logic handles a number of special cases --------- Co-authored-by: Millun Atluri Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> --- .../flow/AddNodePopover/AddNodePopover.tsx | 62 ++++++--- .../features/nodes/components/flow/Flow.tsx | 4 +- .../nodes/hooks/useIsValidConnection.ts | 28 +--- .../nodes/store/nodesPersistDenylist.ts | 3 + .../src/features/nodes/store/nodesSlice.ts | 96 ++++++++++++- .../web/src/features/nodes/store/types.ts | 4 + .../store/util/findConnectionToValidHandle.ts | 126 ++++++++++++++++++ .../nodes/store/util/getIsGraphAcyclic.ts | 26 ++++ .../util/makeIsConnectionValidSelector.ts | 2 +- .../util/validateSourceAndTargetTypes.ts | 7 + 10 files changed, 308 insertions(+), 50 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 9ab413a98f..5ddd1d4ece 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -17,14 +17,15 @@ import { addNodePopoverOpened, nodeAdded, } from 'features/nodes/store/nodesSlice'; -import { map } from 'lodash-es'; +import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; +import { filter, map, some } from 'lodash-es'; import { memo, useCallback, useRef } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; import { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; +import { useTranslation } from 'react-i18next'; import 'reactflow/dist/style.css'; import { AnyInvocationType } from 'services/events/types'; import { AddNodePopoverSelectItem } from './AddNodePopoverSelectItem'; -import { useTranslation } from 'react-i18next'; type NodeTemplate = { label: string; @@ -33,7 +34,7 @@ type NodeTemplate = { tags: string[]; }; -const filter = (value: string, item: NodeTemplate) => { +const selectFilter = (value: string, item: NodeTemplate) => { const regex = new RegExp( value .trim() @@ -55,10 +56,34 @@ const AddNodePopover = () => { const toaster = useAppToaster(); const { t } = useTranslation(); + const fieldFilter = useAppSelector( + (state) => state.nodes.currentConnectionFieldType + ); + const handleFilter = useAppSelector( + (state) => state.nodes.connectionStartParams?.handleType + ); + const selector = createSelector( [stateSelector], ({ nodes }) => { - const data: NodeTemplate[] = map(nodes.nodeTemplates, (template) => { + // If we have a connection in progress, we need to filter the node choices + const filteredNodeTemplates = fieldFilter + ? filter(nodes.nodeTemplates, (template) => { + const handles = + handleFilter == 'source' ? template.inputs : template.outputs; + + return some(handles, (handle) => { + const sourceType = + handleFilter == 'source' ? fieldFilter : handle.type; + const targetType = + handleFilter == 'target' ? fieldFilter : handle.type; + + return validateSourceAndTargetTypes(sourceType, targetType); + }); + }) + : map(nodes.nodeTemplates); + + const data: NodeTemplate[] = map(filteredNodeTemplates, (template) => { return { label: template.title, value: template.type, @@ -67,19 +92,22 @@ const AddNodePopover = () => { }; }); - data.push({ - label: t('nodes.currentImage'), - value: 'current_image', - description: t('nodes.currentImageDescription'), - tags: ['progress'], - }); + //We only want these nodes if we're not filtered + if (fieldFilter === null) { + data.push({ + label: t('nodes.currentImage'), + value: 'current_image', + description: t('nodes.currentImageDescription'), + tags: ['progress'], + }); - data.push({ - label: t('nodes.notes'), - value: 'notes', - description: t('nodes.notesDescription'), - tags: ['notes'], - }); + data.push({ + label: t('nodes.notes'), + value: 'notes', + description: t('nodes.notesDescription'), + tags: ['notes'], + }); + } data.sort((a, b) => a.label.localeCompare(b.label)); @@ -190,7 +218,7 @@ const AddNodePopover = () => { maxDropdownHeight={400} nothingFound={t('nodes.noMatchingNodes')} itemComponent={AddNodePopoverSelectItem} - filter={filter} + filter={selectFilter} onChange={handleChange} hoverOnSearchChange={true} onDropdownClose={onClose} diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index e2ff7c5bb0..25194582a3 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -30,6 +30,7 @@ import { connectionEnded, connectionMade, connectionStarted, + edgeChangeStarted, edgeAdded, edgeDeleted, edgesChanged, @@ -119,7 +120,7 @@ export const Flow = () => { ); const onConnectEnd: OnConnectEnd = useCallback(() => { - dispatch(connectionEnded()); + dispatch(connectionEnded({ cursorPosition: cursorPosition.current })); }, [dispatch]); const onEdgesDelete: OnEdgesDelete = useCallback( @@ -194,6 +195,7 @@ export const Flow = () => { edgeUpdateMouseEvent.current = e; // always delete the edge when starting an updated dispatch(edgeDeleted(edge.id)); + dispatch(edgeChangeStarted()); }, [dispatch] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index a57787556c..c88d4758af 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,9 +1,9 @@ // TODO: enable this at some point -import graphlib from '@dagrejs/graphlib'; import { useAppSelector } from 'app/store/storeHooks'; import { useCallback } from 'react'; -import { Connection, Edge, Node, useReactFlow } from 'reactflow'; +import { Connection, Node, useReactFlow } from 'reactflow'; import { validateSourceAndTargetTypes } from '../store/util/validateSourceAndTargetTypes'; +import { getIsGraphAcyclic } from '../store/util/getIsGraphAcyclic'; import { InvocationNodeData } from '../types/types'; /** @@ -87,27 +87,3 @@ export const useIsValidConnection = () => { return isValidConnection; }; - -export const getIsGraphAcyclic = ( - source: string, - target: string, - nodes: Node[], - edges: Edge[] -) => { - // construct graphlib graph from editor state - const g = new graphlib.Graph(); - - nodes.forEach((n) => { - g.setNode(n.id); - }); - - edges.forEach((e) => { - g.setEdge(e.source, e.target); - }); - - // add the candidate edge - g.setEdge(source, target); - - // check if the graph is acyclic - return graphlib.alg.isAcyclic(g); -}; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts index acf9918a89..64fee2293f 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts @@ -12,4 +12,7 @@ export const nodesPersistDenylist: (keyof NodesState)[] = [ 'isReady', 'nodesToCopy', 'edgesToCopy', + 'connectionMade', + 'modifyingEdge', + 'addNewNodePosition', ]; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 1b3a5ca929..768017d86d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -60,6 +60,7 @@ import { } from '../types/types'; import { NodesState } from './types'; import { findUnoccupiedPosition } from './util/findUnoccupiedPosition'; +import { findConnectionToValidHandle } from './util/findConnectionToValidHandle'; export const WORKFLOW_FORMAT_VERSION = '1.0.0'; @@ -92,6 +93,9 @@ export const initialNodesState: NodesState = { isReady: false, connectionStartParams: null, currentConnectionFieldType: null, + connectionMade: false, + modifyingEdge: false, + addNewNodePosition: null, shouldShowFieldTypeLegend: false, shouldShowMinimapPanel: true, shouldValidateGraph: true, @@ -153,8 +157,8 @@ const nodesSlice = createSlice({ const node = action.payload; const position = findUnoccupiedPosition( state.nodes, - node.position.x, - node.position.y + state.addNewNodePosition?.x ?? node.position.x, + state.addNewNodePosition?.y ?? node.position.y ); node.position = position; node.selected = true; @@ -179,6 +183,38 @@ const nodesSlice = createSlice({ nodeId: node.id, ...initialNodeExecutionState, }; + + if (state.connectionStartParams) { + const { nodeId, handleId, handleType } = state.connectionStartParams; + if ( + nodeId && + handleId && + handleType && + state.currentConnectionFieldType + ) { + const newConnection = findConnectionToValidHandle( + node, + state.nodes, + state.edges, + nodeId, + handleId, + handleType, + state.currentConnectionFieldType + ); + if (newConnection) { + state.edges = addEdge( + { ...newConnection, type: 'default' }, + state.edges + ); + } + } + } + + state.connectionStartParams = null; + state.currentConnectionFieldType = null; + }, + edgeChangeStarted: (state) => { + state.modifyingEdge = true; }, edgesChanged: (state, action: PayloadAction) => { state.edges = applyEdgeChanges(action.payload, state.edges); @@ -195,6 +231,7 @@ const nodesSlice = createSlice({ }, connectionStarted: (state, action: PayloadAction) => { state.connectionStartParams = action.payload; + state.connectionMade = state.modifyingEdge; const { nodeId, handleId, handleType } = action.payload; if (!nodeId || !handleId) { return; @@ -219,10 +256,53 @@ const nodesSlice = createSlice({ { ...action.payload, type: 'default' }, state.edges ); + + state.connectionMade = true; }, - connectionEnded: (state) => { - state.connectionStartParams = null; - state.currentConnectionFieldType = null; + connectionEnded: (state, action) => { + if (!state.connectionMade) { + if (state.mouseOverNode) { + const nodeIndex = state.nodes.findIndex( + (n) => n.id === state.mouseOverNode + ); + const mouseOverNode = state.nodes?.[nodeIndex]; + if (mouseOverNode && state.connectionStartParams) { + const { nodeId, handleId, handleType } = + state.connectionStartParams; + if ( + nodeId && + handleId && + handleType && + state.currentConnectionFieldType + ) { + const newConnection = findConnectionToValidHandle( + mouseOverNode, + state.nodes, + state.edges, + nodeId, + handleId, + handleType, + state.currentConnectionFieldType + ); + if (newConnection) { + state.edges = addEdge( + { ...newConnection, type: 'default' }, + state.edges + ); + } + } + } + state.connectionStartParams = null; + state.currentConnectionFieldType = null; + } else { + state.addNewNodePosition = action.payload.cursorPosition; + state.isAddNodePopoverOpen = true; + } + } else { + state.connectionStartParams = null; + state.currentConnectionFieldType = null; + } + state.modifyingEdge = false; }, workflowExposedFieldAdded: ( state, @@ -835,10 +915,15 @@ const nodesSlice = createSlice({ }); }, addNodePopoverOpened: (state) => { + state.addNewNodePosition = null; //Create the node in viewport center by default state.isAddNodePopoverOpen = true; }, addNodePopoverClosed: (state) => { state.isAddNodePopoverOpen = false; + + //Make sure these get reset if we close the popover and haven't selected a node + state.connectionStartParams = null; + state.currentConnectionFieldType = null; }, addNodePopoverToggled: (state) => { state.isAddNodePopoverOpen = !state.isAddNodePopoverOpen; @@ -913,6 +998,7 @@ export const { connectionMade, connectionStarted, edgeDeleted, + edgeChangeStarted, edgesChanged, edgesDeleted, edgeUpdated, diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 78410c2dba..f6bfa7cad8 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -4,6 +4,7 @@ import { OnConnectStartParams, SelectionMode, Viewport, + XYPosition, } from 'reactflow'; import { FieldIdentifier, @@ -21,6 +22,8 @@ export type NodesState = { nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; currentConnectionFieldType: FieldType | null; + connectionMade: boolean; + modifyingEdge: boolean; shouldShowFieldTypeLegend: boolean; shouldShowMinimapPanel: boolean; shouldValidateGraph: boolean; @@ -39,5 +42,6 @@ export type NodesState = { nodesToCopy: Node[]; edgesToCopy: Edge[]; isAddNodePopoverOpen: boolean; + addNewNodePosition: XYPosition | null; selectionMode: SelectionMode; }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts new file mode 100644 index 0000000000..69386c1f23 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -0,0 +1,126 @@ +import { Connection, HandleType } from 'reactflow'; +import { Node, Edge } from 'reactflow'; +import { + FieldType, + InputFieldValue, + OutputFieldValue, +} from 'features/nodes/types/types'; + +import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; +import { getIsGraphAcyclic } from './getIsGraphAcyclic'; + +const isValidConnection = ( + edges: Edge[], + handleCurrentType: HandleType, + handleCurrentFieldType: FieldType, + node: Node, + handle: InputFieldValue | OutputFieldValue +) => { + let isValidConnection = true; + if (handleCurrentType === 'source') { + if ( + edges.find((edge) => { + return edge.target === node.id && edge.targetHandle === handle.name; + }) + ) { + isValidConnection = false; + } + } else { + if ( + edges.find((edge) => { + return edge.source === node.id && edge.sourceHandle === handle.name; + }) + ) { + isValidConnection = false; + } + } + + if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) { + isValidConnection = false; + } + + return isValidConnection; +}; + +export const findConnectionToValidHandle = ( + node: Node, + nodes: Node[], + edges: Edge[], + handleCurrentNodeId: string, + handleCurrentName: string, + handleCurrentType: HandleType, + handleCurrentFieldType: FieldType +): Connection | null => { + if (node.id === handleCurrentNodeId) { + return null; + } + + const handles = + handleCurrentType == 'source' ? node.data.inputs : node.data.outputs; + + //Prioritize handles whos name matches the node we're coming from + if (handles[handleCurrentName]) { + const handle = handles[handleCurrentName]; + + const sourceID = + handleCurrentType == 'source' ? handleCurrentNodeId : node.id; + const targetID = + handleCurrentType == 'source' ? node.id : handleCurrentNodeId; + const sourceHandle = + handleCurrentType == 'source' ? handleCurrentName : handle.name; + const targetHandle = + handleCurrentType == 'source' ? handle.name : handleCurrentName; + + const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges); + + const valid = isValidConnection( + edges, + handleCurrentType, + handleCurrentFieldType, + node, + handle + ); + + if (isGraphAcyclic && valid) { + return { + source: sourceID, + sourceHandle: sourceHandle, + target: targetID, + targetHandle: targetHandle, + }; + } + } + + for (const handleName in handles) { + const handle = handles[handleName]; + + const sourceID = + handleCurrentType == 'source' ? handleCurrentNodeId : node.id; + const targetID = + handleCurrentType == 'source' ? node.id : handleCurrentNodeId; + const sourceHandle = + handleCurrentType == 'source' ? handleCurrentName : handle.name; + const targetHandle = + handleCurrentType == 'source' ? handle.name : handleCurrentName; + + const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges); + + const valid = isValidConnection( + edges, + handleCurrentType, + handleCurrentFieldType, + node, + handle + ); + + if (isGraphAcyclic && valid) { + return { + source: sourceID, + sourceHandle: sourceHandle, + target: targetID, + targetHandle: targetHandle, + }; + } + } + return null; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts b/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts new file mode 100644 index 0000000000..421813a687 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/getIsGraphAcyclic.ts @@ -0,0 +1,26 @@ +import graphlib from '@dagrejs/graphlib'; +import { Edge, Node } from 'reactflow'; + +export const getIsGraphAcyclic = ( + source: string, + target: string, + nodes: Node[], + edges: Edge[] +) => { + // construct graphlib graph from editor state + const g = new graphlib.Graph(); + + nodes.forEach((n) => { + g.setNode(n.id); + }); + + edges.forEach((e) => { + g.setEdge(e.source, e.target); + }); + + // add the candidate edge + g.setEdge(source, target); + + // check if the graph is acyclic + return graphlib.alg.isAcyclic(g); +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts index 6343240a88..57dd284b88 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeIsConnectionValidSelector.ts @@ -1,6 +1,6 @@ import { createSelector } from '@reduxjs/toolkit'; import { stateSelector } from 'app/store/store'; -import { getIsGraphAcyclic } from 'features/nodes/hooks/useIsValidConnection'; +import { getIsGraphAcyclic } from './getIsGraphAcyclic'; import { FieldType } from 'features/nodes/types/types'; import i18n from 'i18next'; import { HandleType } from 'reactflow'; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts index 4f0be3329a..8c2bef34fe 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateSourceAndTargetTypes.ts @@ -10,6 +10,13 @@ export const validateSourceAndTargetTypes = ( sourceType: FieldType, targetType: FieldType ) => { + // TODO: There's a bug with Collect -> Iterate nodes: + // https://github.com/invoke-ai/InvokeAI/issues/3956 + // Once this is resolved, we can remove this check. + if (sourceType === 'Collection' && targetType === 'Collection') { + return false; + } + if (sourceType === targetType) { return true; }