diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index f6091e4a13..df628ba5af 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -70,7 +70,14 @@ export const useConnection = () => { } const candidateTemplate = templates[candidateNode.data.type]; assert(candidateTemplate, `Template not found for node type: ${candidateNode.data.type}`); - const connection = getFirstValidConnection(templates, nodes, edges, pendingConnection, candidateNode, candidateTemplate); + const connection = getFirstValidConnection( + templates, + nodes, + edges, + pendingConnection, + candidateNode, + candidateTemplate + ); if (connection) { dispatch(connectionMade(connection)); } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts index 727c0932f7..9acd5722cf 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useCopyPaste.ts @@ -1,6 +1,12 @@ import { getStore } from 'app/store/nanostores/store'; import { deepClone } from 'common/util/deepClone'; -import { $copiedEdges,$copiedNodes,$cursorPos, selectionPasted, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { + $copiedEdges, + $copiedNodes, + $cursorPos, + selectionPasted, + selectNodesSlice, +} from 'features/nodes/store/nodesSlice'; import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupiedPosition'; import { v4 as uuidv4 } from 'uuid'; diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts index be8cfafa8b..4739a77e1c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/selectors.ts +++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts @@ -4,7 +4,7 @@ import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/in import { isInvocationNode } from 'features/nodes/types/invocation'; import { assert } from 'tsafe'; -export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => { +const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode => { const node = nodesSlice.nodes.find((node) => node.id === nodeId); assert(isInvocationNode(node), `Node ${nodeId} is not an invocation node`); return node; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts index cd69640dca..4c47cb15b0 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/findConnectionToValidHandle.ts @@ -1,120 +1,13 @@ import type { PendingConnection, Templates } from 'features/nodes/store/types'; import { getCollectItemType } from 'features/nodes/store/util/makeIsConnectionValidSelector'; -import type { FieldInputTemplate, FieldOutputTemplate, FieldType } from 'features/nodes/types/field'; import type { AnyNode, InvocationNode, InvocationNodeEdge, InvocationTemplate } from 'features/nodes/types/invocation'; -import { isInvocationNode } from 'features/nodes/types/invocation'; import { differenceWith, isEqual, map } from 'lodash-es'; -import type { Connection, Edge, HandleType, Node } from 'reactflow'; +import type { Connection } from 'reactflow'; import { assert } from 'tsafe'; import { getIsGraphAcyclic } from './getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from './validateSourceAndTargetTypes'; -const isValidConnection = ( - edges: Edge[], - handleCurrentType: HandleType, - handleCurrentFieldType: FieldType, - node: Node, - handle: FieldInputTemplate | FieldOutputTemplate -) => { - let isValidConnection = true; - if (handleCurrentType === 'source') { - if ( - edges.find((edge) => { - return edge.target === node.id && edge.targetHandle === handle.name; - }) - ) { - isValidConnection = false; - } - } else { - if ( - edges.find((edge) => { - return edge.source === node.id && edge.sourceHandle === handle.name; - }) - ) { - isValidConnection = false; - } - } - - if (!validateSourceAndTargetTypes(handleCurrentFieldType, handle.type)) { - isValidConnection = false; - } - - return isValidConnection; -}; - -export const findConnectionToValidHandle = ( - node: AnyNode, - nodes: AnyNode[], - edges: InvocationNodeEdge[], - templates: Templates, - handleCurrentNodeId: string, - handleCurrentName: string, - handleCurrentType: HandleType, - handleCurrentFieldType: FieldType -): Connection | null => { - if (node.id === handleCurrentNodeId || !isInvocationNode(node)) { - return null; - } - - const template = templates[node.data.type]; - - if (!template) { - return null; - } - - const handles = handleCurrentType === 'source' ? template.inputs : template.outputs; - - //Prioritize handles whos name matches the node we're coming from - const handle = handles[handleCurrentName]; - - if (handle) { - const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; - const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; - const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name; - const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName; - - const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges); - - const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle); - - if (isGraphAcyclic && valid) { - return { - source: sourceID, - sourceHandle: sourceHandle, - target: targetID, - targetHandle: targetHandle, - }; - } - } - - for (const handleName in handles) { - const handle = handles[handleName]; - if (!handle) { - continue; - } - - const sourceID = handleCurrentType === 'source' ? handleCurrentNodeId : node.id; - const targetID = handleCurrentType === 'source' ? node.id : handleCurrentNodeId; - const sourceHandle = handleCurrentType === 'source' ? handleCurrentName : handle.name; - const targetHandle = handleCurrentType === 'source' ? handle.name : handleCurrentName; - - const isGraphAcyclic = getIsGraphAcyclic(sourceID, targetID, nodes, edges); - - const valid = isValidConnection(edges, handleCurrentType, handleCurrentFieldType, node, handle); - - if (isGraphAcyclic && valid) { - return { - source: sourceID, - sourceHandle: sourceHandle, - target: targetID, - targetHandle: targetHandle, - }; - } - } - return null; -}; - export const getFirstValidConnection = ( templates: Templates, nodes: AnyNode[], diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts index 7487fd488b..4a2e45abde 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts @@ -3,7 +3,7 @@ import { createSlice } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import { SelectionMode } from 'reactflow'; -export type WorkflowSettingsState = { +type WorkflowSettingsState = { _version: 1; shouldShowMinimapPanel: boolean; shouldValidateGraph: boolean; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index c57a7213b8..d2d3d64cb0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -31,10 +31,7 @@ type ValidateWorkflowResult = { * @throws {WorkflowVersionError} If the workflow version is not recognized. * @throws {z.ZodError} If there is a validation error. */ -export const validateWorkflow = ( - workflow: unknown, - invocationTemplates: Templates -): ValidateWorkflowResult => { +export const validateWorkflow = (workflow: unknown, invocationTemplates: Templates): ValidateWorkflowResult => { // Parse the raw workflow data & migrate it to the latest version const _workflow = parseAndMigrateWorkflow(workflow);