From cd785ba64b7b9ceb95f2003b3e5e5f09f115c99d Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Sun, 16 Feb 2025 11:37:26 +1000 Subject: [PATCH] perf(ui): optimize field handle/title/etc rendering --- .../flow/nodes/Invocation/InvocationNode.tsx | 23 +++-- .../nodes/Invocation/InvocationNodeHeader.tsx | 25 +++-- .../fields/InputFieldEditModeNodes.tsx | 91 ++++++++++++++----- .../Invocation/fields/InputFieldHandle.tsx | 91 ++++++++++++++++--- .../Invocation/fields/InputFieldTitle.tsx | 6 +- .../Invocation/fields/OutputFieldHandle.tsx | 91 ++++++++++++++++--- .../Invocation/fields/OutputFieldTitle.tsx | 6 +- .../flow/nodes/common/NodeCollapseButton.tsx | 19 ++-- .../nodes/hooks/useFieldConnectionState.ts | 8 +- .../nodes/hooks/useIsValidConnection.ts | 4 +- .../features/nodes/hooks/useNodeCopyPaste.ts | 6 +- .../store/util/getFirstValidConnection.ts | 8 +- .../store/util/makeConnectionErrorSelector.ts | 12 +-- .../store/util/validateConnection.test.ts | 42 ++++----- .../nodes/store/util/validateConnection.ts | 52 +++++------ 15 files changed, 330 insertions(+), 154 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx index 60ab3295a4..83aae79f13 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx @@ -1,3 +1,4 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Flex, Grid, GridItem } from '@invoke-ai/ui-library'; import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate'; import { OutputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldGate'; @@ -20,6 +21,18 @@ type Props = { isOpen: boolean; }; +const sx: SystemStyleObject = { + flexDirection: 'column', + w: 'full', + h: 'full', + py: 2, + gap: 1, + borderBottomRadius: 'base', + '&[data-with-footer="true"]': { + borderBottomRadius: 0, + }, +}; + const InvocationNode = ({ nodeId, isOpen }: Props) => { const withFooter = useWithFooter(nodeId); @@ -28,15 +41,7 @@ const InvocationNode = ({ nodeId, isOpen }: Props) => { {isOpen && ( <> - + diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx index f0bbab5e12..c3ff0749dd 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx @@ -1,3 +1,4 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Flex } from '@invoke-ai/ui-library'; import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton'; import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle'; @@ -13,18 +14,22 @@ type Props = { isOpen: boolean; }; +const sx: SystemStyleObject = { + borderTopRadius: 'base', + alignItems: 'center', + justifyContent: 'space-between', + h: 8, + textAlign: 'center', + color: 'base.200', + borderBottomRadius: 'base', + '&[data-is-open="true"]': { + borderBottomRadius: 0, + }, +}; + const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => { return ( - + diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx index de327ef133..ef0bb38f07 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx @@ -1,3 +1,4 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Flex, FormControl, Spacer } from '@invoke-ai/ui-library'; import { InputFieldDescriptionPopover } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldDescriptionPopover'; import { InputFieldHandle } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle'; @@ -6,6 +7,7 @@ import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected'; import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid'; import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate'; +import type { FieldInputTemplate } from 'features/nodes/types/field'; import { memo, useCallback, useRef, useState } from 'react'; import { InputFieldRenderer } from './InputFieldRenderer'; @@ -19,11 +21,72 @@ interface Props { export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => { const fieldTemplate = useInputFieldTemplate(nodeId, fieldName); + const isInvalid = useInputFieldIsInvalid(nodeId, fieldName); + const isConnected = useInputFieldIsConnected(nodeId, fieldName); + + if (fieldTemplate.input === 'connection' || isConnected) { + return ( + + ); + } + + return ( + + ); +}); + +InputFieldEditModeNodes.displayName = 'InputFieldEditModeNodes'; + +type CommonProps = { + nodeId: string; + fieldName: string; + isInvalid: boolean; + isConnected: boolean; + fieldTemplate: FieldInputTemplate; +}; + +const ConnectedOrConnectionField = memo(({ nodeId, fieldName, isInvalid, isConnected }: CommonProps) => { + return ( + + + + + + + ); +}); +ConnectedOrConnectionField.displayName = 'ConnectedOrConnectionField'; + +const directFieldSx: SystemStyleObject = { + orientation: 'vertical', + px: 2, + '&[data-is-dragging="true"]': { + opacity: 0.3, + }, + // Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a + // connection, the mouse up to end the connection won't fire, leaving the connection in-progress. + pointerEvents: 'auto', + '&[data-is-connected="true"]': { + pointerEvents: 'none', + }, +}; + +const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemplate }: CommonProps) => { const draggableRef = useRef(null); const dragHandleRef = useRef(null); const [isHovered, setIsHovered] = useState(false); - const isInvalid = useInputFieldIsInvalid(nodeId, fieldName); - const isConnected = useInputFieldIsConnected(nodeId, fieldName); const onMouseEnter = useCallback(() => { setIsHovered(true); @@ -35,30 +98,15 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => { const isDragging = useNodeFieldDnd({ nodeId, fieldName }, fieldTemplate, draggableRef, dragHandleRef); - if (fieldTemplate.input === 'connection' || isConnected) { - return ( - - - - - - - - ); - } - return ( @@ -79,5 +127,4 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => { ); }); - -InputFieldEditModeNodes.displayName = 'InputFieldEditModeNodes'; +DirectField.displayName = 'DirectField'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle.tsx index b0f231e517..0111f4af2d 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle.tsx @@ -3,13 +3,14 @@ import { Box, Tooltip } from '@invoke-ai/ui-library'; import { Handle, Position } from '@xyflow/react'; import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { - useConnectionValidationResult, + useConnectionErrorTKey, useIsConnectionInProgress, useIsConnectionStartField, } from 'features/nodes/hooks/useFieldConnectionState'; import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate'; import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants'; +import type { FieldInputTemplate } from 'features/nodes/types/field'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -60,33 +61,60 @@ const handleStyles = { } satisfies CSSProperties; export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => { - const { t } = useTranslation(); const fieldTemplate = useInputFieldTemplate(nodeId, fieldName); const fieldTypeName = useFieldTypeName(fieldTemplate.type); const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]); const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]); - const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target'); const isConnectionInProgress = useIsConnectionInProgress(); - const validationResult = useConnectionValidationResult(nodeId, fieldName, 'target'); - const tooltip = useMemo(() => { - if (isConnectionInProgress && validationResult.messageTKey) { - return t(validationResult.messageTKey); - } - return fieldTypeName; - }, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]); + if (isConnectionInProgress) { + return ( + + ); + } return ( - + + ); +}); + +InputFieldHandle.displayName = 'InputFieldHandle'; + +type HandleCommonProps = { + nodeId: string; + fieldName: string; + fieldTemplate: FieldInputTemplate; + fieldTypeName: string; + fieldColor: string; + isModelField: boolean; +}; + +const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => { + return ( + @@ -94,5 +122,38 @@ export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => { ); }); +IdleHandle.displayName = 'IdleHandle'; -InputFieldHandle.displayName = 'InputFieldHandle'; +const ConnectionInProgressHandle = memo( + ({ nodeId, fieldName, fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => { + const { t } = useTranslation(); + const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target'); + const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target'); + + const tooltip = useMemo(() => { + if (connectionError !== null) { + return t(connectionError); + } + return fieldTypeName; + }, [fieldTypeName, t, connectionError]); + + return ( + + + + + + ); + } +); +ConnectionInProgressHandle.displayName = 'ConnectionInProgressHandle'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldTitle.tsx index ae9001418d..d1eccd51a7 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldTitle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldTitle.tsx @@ -4,7 +4,7 @@ import { useAppDispatch } from 'app/store/storeHooks'; import { useEditable } from 'common/hooks/useEditable'; import { InputFieldTooltipContent } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldTooltipContent'; import { - useConnectionValidationResult, + useConnectionErrorTKey, useIsConnectionInProgress, useIsConnectionStartField, } from 'features/nodes/hooks/useFieldConnectionState'; @@ -47,7 +47,7 @@ export const InputFieldTitle = memo((props: Props) => { const isConnected = useInputFieldIsConnected(nodeId, fieldName); const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target'); const isConnectionInProgress = useIsConnectionInProgress(); - const validationResult = useConnectionValidationResult(nodeId, fieldName, 'target'); + const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target'); const dispatch = useAppDispatch(); const defaultTitle = useMemo(() => fieldTemplateTitle || t('nodes.unknownField'), [fieldTemplateTitle, t]); @@ -76,7 +76,7 @@ export const InputFieldTitle = memo((props: Props) => { noOfLines={1} data-is-invalid={isInvalid} data-is-disabled={ - (isConnectionInProgress && !validationResult.isValid && !isConnectionStartField) || isConnected + (isConnectionInProgress && connectionError !== null && !isConnectionStartField) || isConnected } onDoubleClick={editable.startEditing} > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldHandle.tsx index 7c9fd92d55..6f44702ad6 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldHandle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldHandle.tsx @@ -3,13 +3,14 @@ import { Box, Tooltip } from '@invoke-ai/ui-library'; import { Handle, Position } from '@xyflow/react'; import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor'; import { - useConnectionValidationResult, + useConnectionErrorTKey, useIsConnectionInProgress, useIsConnectionStartField, } from 'features/nodes/hooks/useFieldConnectionState'; import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate'; import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType'; import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants'; +import type { FieldOutputTemplate } from 'features/nodes/types/field'; import type { CSSProperties } from 'react'; import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; @@ -60,33 +61,60 @@ const handleStyles = { } satisfies CSSProperties; export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => { - const { t } = useTranslation(); const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName); const fieldTypeName = useFieldTypeName(fieldTemplate.type); const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]); const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]); - const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source'); const isConnectionInProgress = useIsConnectionInProgress(); - const validationResult = useConnectionValidationResult(nodeId, fieldName, 'source'); - const tooltip = useMemo(() => { - if (isConnectionInProgress && validationResult.messageTKey) { - return t(validationResult.messageTKey); - } - return fieldTypeName; - }, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]); + if (isConnectionInProgress) { + return ( + + ); + } return ( - + + ); +}); + +OutputFieldHandle.displayName = 'OutputFieldHandle'; + +type HandleCommonProps = { + nodeId: string; + fieldName: string; + fieldTemplate: FieldOutputTemplate; + fieldTypeName: string; + fieldColor: string; + isModelField: boolean; +}; + +const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => { + return ( + @@ -94,5 +122,38 @@ export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => { ); }); +IdleHandle.displayName = 'IdleHandle'; -OutputFieldHandle.displayName = 'OutputFieldHandle'; +const ConnectionInProgressHandle = memo( + ({ nodeId, fieldName, fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => { + const { t } = useTranslation(); + const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target'); + const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'target'); + + const tooltip = useMemo(() => { + if (connectionErrorTKey !== null) { + return t(connectionErrorTKey); + } + return fieldTypeName; + }, [fieldTypeName, t, connectionErrorTKey]); + + return ( + + + + + + ); + } +); +ConnectionInProgressHandle.displayName = 'ConnectionInProgressHandle'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTitle.tsx index 311a917c2a..22a4bb9f38 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTitle.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTitle.tsx @@ -2,7 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Text, Tooltip } from '@invoke-ai/ui-library'; import { OutputFieldTooltipContent } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTooltipContent'; import { - useConnectionValidationResult, + useConnectionErrorTKey, useIsConnectionInProgress, useIsConnectionStartField, } from 'features/nodes/hooks/useFieldConnectionState'; @@ -31,7 +31,7 @@ export const OutputFieldTitle = memo(({ nodeId, fieldName }: Props) => { const isConnected = useInputFieldIsConnected(nodeId, fieldName); const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source'); const isConnectionInProgress = useIsConnectionInProgress(); - const validationResult = useConnectionValidationResult(nodeId, fieldName, 'source'); + const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'source'); return ( { > diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeCollapseButton.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeCollapseButton.tsx index 2117843887..b485a928b9 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeCollapseButton.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeCollapseButton.tsx @@ -1,3 +1,4 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Icon, IconButton } from '@invoke-ai/ui-library'; import { useUpdateNodeInternals } from '@xyflow/react'; import { useAppDispatch } from 'app/store/storeHooks'; @@ -10,6 +11,15 @@ interface Props { isOpen: boolean; } +const iconSx: SystemStyleObject = { + transitionProperty: 'transform', + transitionDuration: 'normal', + transform: 'rotate(180deg)', + '&[data-is-open="true"]': { + transform: 'rotate(0deg)', + }, +}; + const NodeCollapseButton = ({ nodeId, isOpen }: Props) => { const dispatch = useAppDispatch(); const updateNodeInternals = useUpdateNodeInternals(); @@ -28,14 +38,7 @@ const NodeCollapseButton = ({ nodeId, isOpen }: Props) => { w={8} h={8} variant="link" - icon={ - - } + icon={} /> ); }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldConnectionState.ts index b860e0df22..62dfb723cd 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldConnectionState.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldConnectionState.ts @@ -10,18 +10,18 @@ import { import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector'; import { useMemo } from 'react'; -export const useConnectionValidationResult = (nodeId: string, fieldName: string, handleType: HandleType) => { +export const useConnectionErrorTKey = (nodeId: string, fieldName: string, handleType: HandleType): string | null => { const pendingConnection = useStore($pendingConnection); const templates = useStore($templates); const edgePendingUpdate = useStore($edgePendingUpdate); - const selectValidationResult = useMemo( + const selectConnectionError = useMemo( () => makeConnectionErrorSelector(templates, nodeId, fieldName, handleType, pendingConnection, edgePendingUpdate), [templates, nodeId, fieldName, handleType, pendingConnection, edgePendingUpdate] ); - const validationResult = useAppSelector(selectValidationResult); - return validationResult; + const connectionError = useAppSelector(selectConnectionError); + return connectionError; }; export const useIsConnectionStartField = (nodeId: string, fieldName: string, handleType: HandleType) => { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index e4337018bc..a4f47a6ce0 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -21,7 +21,7 @@ export const useIsValidConnection = (): IsValidConnection => { const edgePendingUpdate = $edgePendingUpdate.get(); const { nodes, edges } = selectNodesSlice(store.getState()); - const validationResult = validateConnection( + const connectionErrorTKey = validateConnection( { source, sourceHandle, target, targetHandle }, nodes, edges, @@ -30,7 +30,7 @@ export const useIsValidConnection = (): IsValidConnection => { shouldValidateGraph ); - return validationResult.isValid; + return connectionErrorTKey === null; }, [templates, shouldValidateGraph, store] ); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeCopyPaste.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeCopyPaste.ts index 3974857ea1..7fdcef98b6 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeCopyPaste.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeCopyPaste.ts @@ -126,7 +126,7 @@ const _pasteSelection = (withEdgesToCopiedNodes?: boolean) => { assert(!isNil(targetHandle)); // Validate the edge before adding it - const validationResult = validateConnection( + const connectionErrorTKey = validateConnection( { source, sourceHandle, target, targetHandle }, validationNodes, validationEdges, @@ -135,10 +135,10 @@ const _pasteSelection = (withEdgesToCopiedNodes?: boolean) => { true ); // If the edge is invalid, log a warning and skip it - if (!validationResult.isValid) { + if (connectionErrorTKey !== null) { log.warn( { edge: { source, sourceHandle, target, targetHandle } }, - `Invalid edge, cannot paste: ${t(validationResult.messageTKey)}` + `Invalid edge, cannot paste: ${t(connectionErrorTKey)}` ); return; } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 2a88e2078a..9789e991a7 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -105,8 +105,8 @@ export const getTargetCandidateFields = ( const targetCandidateFields = map(targetTemplate.inputs).filter((field) => { const c = { source, sourceHandle, target, targetHandle: field.name }; - const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); - return r.isValid; + const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); + return connectionErrorTKey === null; }); return targetCandidateFields; @@ -141,8 +141,8 @@ export const getSourceCandidateFields = ( const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => { const c = { source, sourceHandle: field.name, target, targetHandle }; - const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); - return r.isValid; + const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true); + return connectionErrorTKey === null; }); return sourceCandidateFields; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts index 00492b994e..fa10fcc3e6 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts @@ -1,8 +1,8 @@ +import { createSelector } from '@reduxjs/toolkit'; import type { HandleType } from '@xyflow/react'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { selectNodesSlice } from 'features/nodes/store/selectors'; import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types'; -import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import type { AnyEdge } from 'features/nodes/types/invocation'; /** @@ -25,18 +25,18 @@ export const makeConnectionErrorSelector = ( pendingConnection: PendingConnection | null, edgePendingUpdate: AnyEdge | null ) => { - return createMemoizedSelector(selectNodesSlice, (nodesSlice: NodesState) => { + return createSelector(selectNodesSlice, (nodesSlice: NodesState): string | null => { const { nodes, edges } = nodesSlice; if (!pendingConnection) { - return buildRejectResult('nodes.noConnectionInProgress'); + return 'nodes.noConnectionInProgress'; } if (handleType === pendingConnection.handleType) { if (handleType === 'source') { - return buildRejectResult('nodes.cannotConnectOutputToOutput'); + return 'nodes.cannotConnectOutputToOutput'; } - return buildRejectResult('nodes.cannotConnectInputToInput'); + return 'nodes.cannotConnectInputToInput'; } // we have to figure out which is the target and which is the source diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 19035afd54..c50094be06 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -3,13 +3,13 @@ import { set } from 'lodash-es'; import { describe, expect, it } from 'vitest'; import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils'; -import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection'; +import { validateConnection } from './validateConnection'; describe(validateConnection.name, () => { it('should reject invalid connection to self', () => { const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; const r = validateConnection(c, [], [], templates, null); - expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); + expect(r).toEqual('nodes.cannotConnectToSelf'); }); describe('missing nodes', () => { @@ -19,12 +19,12 @@ describe(validateConnection.name, () => { it('should reject missing source node', () => { const r = validateConnection(c, [n2], [], templates, null); - expect(r).toEqual(buildRejectResult('nodes.missingNode')); + expect(r).toEqual('nodes.missingNode'); }); it('should reject missing target node', () => { const r = validateConnection(c, [n1], [], templates, null); - expect(r).toEqual(buildRejectResult('nodes.missingNode')); + expect(r).toEqual('nodes.missingNode'); }); }); @@ -36,12 +36,12 @@ describe(validateConnection.name, () => { it('should reject missing source template', () => { const r = validateConnection(c, nodes, [], { sub }, null); - expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate')); + expect(r).toEqual('nodes.missingInvocationTemplate'); }); it('should reject missing target template', () => { const r = validateConnection(c, nodes, [], { add }, null); - expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate')); + expect(r).toEqual('nodes.missingInvocationTemplate'); }); }); @@ -53,13 +53,13 @@ describe(validateConnection.name, () => { it('should reject missing source field template', () => { const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' }; const r = validateConnection(c, nodes, [], templates, null); - expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate')); + expect(r).toEqual('nodes.missingFieldTemplate'); }); it('should reject missing target field template', () => { const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' }; const r = validateConnection(c, nodes, [], templates, null); - expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate')); + expect(r).toEqual('nodes.missingFieldTemplate'); }); }); @@ -69,19 +69,19 @@ describe(validateConnection.name, () => { it('should accept non-duplicate connections', () => { const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const r = validateConnection(c, [n1, n2], [], templates, null); - expect(r).toEqual(buildAcceptResult()); + expect(r).toEqual(null); }); it('should reject duplicate connections', () => { const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const e = buildEdge(n1.id, 'value', n2.id, 'a'); const r = validateConnection(c, [n1, n2], [e], templates, null); - expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection')); + expect(r).toEqual('nodes.cannotDuplicateConnection'); }); it('should accept duplicate connections if the duplicate is an ignored edge', () => { const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const e = buildEdge(n1.id, 'value', n2.id, 'a'); const r = validateConnection(c, [n1, n2], [e], templates, e); - expect(r).toEqual(buildAcceptResult()); + expect(r).toEqual(null); }); }); @@ -95,7 +95,7 @@ describe(validateConnection.name, () => { const n2 = buildNode(addWithDirectAField); const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null); - expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput')); + expect(r).toEqual('nodes.cannotConnectToDirectInput'); }); it('should reject connection to a collect node with mismatched item types', () => { @@ -107,7 +107,7 @@ describe(validateConnection.name, () => { const edges = [e1]; const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' }; const r = validateConnection(c, nodes, edges, templates, null); - expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes')); + expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes'); }); it('should accept connection to a collect node with matching item types', () => { @@ -119,7 +119,7 @@ describe(validateConnection.name, () => { const edges = [e1]; const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' }; const r = validateConnection(c, nodes, edges, templates, null); - expect(r).toEqual(buildAcceptResult()); + expect(r).toEqual(null); }); it('should reject connections to target field that is already connected', () => { @@ -131,7 +131,7 @@ describe(validateConnection.name, () => { const edges = [e1]; const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const r = validateConnection(c, nodes, edges, templates, null); - expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection')); + expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection'); }); it('should accept connections to target field that is already connected (ignored edge)', () => { @@ -143,7 +143,7 @@ describe(validateConnection.name, () => { const edges = [e1]; const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' }; const r = validateConnection(c, nodes, edges, templates, e1); - expect(r).toEqual(buildAcceptResult()); + expect(r).toEqual(null); }); it('should reject connections between invalid types', () => { @@ -152,7 +152,7 @@ describe(validateConnection.name, () => { const nodes = [n1, n2]; const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; const r = validateConnection(c, nodes, [], templates, null); - expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch')); + expect(r).toEqual('nodes.fieldTypesMustMatch'); }); it('should reject connections that would create cycles', () => { @@ -163,14 +163,14 @@ describe(validateConnection.name, () => { const edges = [e1]; const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }; const r = validateConnection(c, nodes, edges, templates, null); - expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); + expect(r).toEqual('nodes.connectionWouldCreateCycle'); }); describe('non-strict mode', () => { it('should reject connections from self to self in non-strict mode', () => { const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; const r = validateConnection(c, [], [], templates, null, false); - expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf')); + expect(r).toEqual('nodes.cannotConnectToSelf'); }); it('should reject connections that create cycles in non-strict mode', () => { const n1 = buildNode(add); @@ -180,7 +180,7 @@ describe(validateConnection.name, () => { const edges = [e1]; const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }; const r = validateConnection(c, nodes, edges, templates, null, false); - expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle')); + expect(r).toEqual('nodes.connectionWouldCreateCycle'); }); it('should otherwise allow invalid connections in non-strict mode', () => { const n1 = buildNode(add); @@ -188,7 +188,7 @@ describe(validateConnection.name, () => { const nodes = [n1, n2]; const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' }; const r = validateConnection(c, nodes, [], templates, null, false); - expect(r).toEqual(buildAcceptResult()); + expect(r).toEqual(null); }); }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index 02a55e3121..03c66ba8d8 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -9,16 +9,6 @@ import type { SetNonNullable } from 'type-fest'; type Connection = SetNonNullable; -type ValidationResult = - | { - isValid: true; - messageTKey?: string; - } - | { - isValid: false; - messageTKey: string; - }; - type ValidateConnectionFunc = ( connection: Connection, nodes: AnyNode[], @@ -26,7 +16,7 @@ type ValidateConnectionFunc = ( templates: Templates, ignoreEdge: AnyEdge | null, strict?: boolean -) => ValidationResult; +) => string | null; const getEqualityPredicate = (c: Connection) => @@ -45,12 +35,16 @@ const getTargetEqualityPredicate = return e.target === c.target && e.targetHandle === c.targetHandle; }; -export const buildAcceptResult = (): ValidationResult => ({ isValid: true }); -export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey }); - -export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => { +export const validateConnection: ValidateConnectionFunc = ( + c, + nodes, + edges, + templates, + ignoreEdge, + strict = true +): string | null => { if (c.source === c.target) { - return buildRejectResult('nodes.cannotConnectToSelf'); + return 'nodes.cannotConnectToSelf'; } if (strict) { @@ -65,66 +59,66 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp if (filteredEdges.some(getEqualityPredicate(c))) { // We already have a connection from this source to this target - return buildRejectResult('nodes.cannotDuplicateConnection'); + return 'nodes.cannotDuplicateConnection'; } const sourceNode = nodes.find((n) => n.id === c.source); if (!sourceNode) { - return buildRejectResult('nodes.missingNode'); + return 'nodes.missingNode'; } const targetNode = nodes.find((n) => n.id === c.target); if (!targetNode) { - return buildRejectResult('nodes.missingNode'); + return 'nodes.missingNode'; } const sourceTemplate = templates[sourceNode.data.type]; if (!sourceTemplate) { - return buildRejectResult('nodes.missingInvocationTemplate'); + return 'nodes.missingInvocationTemplate'; } const targetTemplate = templates[targetNode.data.type]; if (!targetTemplate) { - return buildRejectResult('nodes.missingInvocationTemplate'); + return 'nodes.missingInvocationTemplate'; } const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; if (!sourceFieldTemplate) { - return buildRejectResult('nodes.missingFieldTemplate'); + return 'nodes.missingFieldTemplate'; } const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; if (!targetFieldTemplate) { - return buildRejectResult('nodes.missingFieldTemplate'); + return 'nodes.missingFieldTemplate'; } if (targetFieldTemplate.input === 'direct') { - return buildRejectResult('nodes.cannotConnectToDirectInput'); + return 'nodes.cannotConnectToDirectInput'; } if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { // Collect nodes shouldn't mix and match field types. const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { - return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'); + return 'nodes.cannotMixAndMatchCollectionItemTypes'; } } if (filteredEdges.find(getTargetEqualityPredicate(c))) { // CollectionItemField inputs can have multiple input connections if (targetFieldTemplate.type.name !== 'CollectionItemField') { - return buildRejectResult('nodes.inputMayOnlyHaveOneConnection'); + return 'nodes.inputMayOnlyHaveOneConnection'; } } if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) { - return buildRejectResult('nodes.fieldTypesMustMatch'); + return 'nodes.fieldTypesMustMatch'; } } if (getHasCycles(c.source, c.target, nodes, edges)) { - return buildRejectResult('nodes.connectionWouldCreateCycle'); + return 'nodes.connectionWouldCreateCycle'; } - return buildAcceptResult(); + return null; };