From cd785ba64b7b9ceb95f2003b3e5e5f09f115c99d Mon Sep 17 00:00:00 2001
From: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
Date: Sun, 16 Feb 2025 11:37:26 +1000
Subject: [PATCH] perf(ui): optimize field handle/title/etc rendering
---
.../flow/nodes/Invocation/InvocationNode.tsx | 23 +++--
.../nodes/Invocation/InvocationNodeHeader.tsx | 25 +++--
.../fields/InputFieldEditModeNodes.tsx | 91 ++++++++++++++-----
.../Invocation/fields/InputFieldHandle.tsx | 91 ++++++++++++++++---
.../Invocation/fields/InputFieldTitle.tsx | 6 +-
.../Invocation/fields/OutputFieldHandle.tsx | 91 ++++++++++++++++---
.../Invocation/fields/OutputFieldTitle.tsx | 6 +-
.../flow/nodes/common/NodeCollapseButton.tsx | 19 ++--
.../nodes/hooks/useFieldConnectionState.ts | 8 +-
.../nodes/hooks/useIsValidConnection.ts | 4 +-
.../features/nodes/hooks/useNodeCopyPaste.ts | 6 +-
.../store/util/getFirstValidConnection.ts | 8 +-
.../store/util/makeConnectionErrorSelector.ts | 12 +--
.../store/util/validateConnection.test.ts | 42 ++++-----
.../nodes/store/util/validateConnection.ts | 52 +++++------
15 files changed, 330 insertions(+), 154 deletions(-)
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx
index 60ab3295a4..83aae79f13 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNode.tsx
@@ -1,3 +1,4 @@
+import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Grid, GridItem } from '@invoke-ai/ui-library';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { OutputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldGate';
@@ -20,6 +21,18 @@ type Props = {
isOpen: boolean;
};
+const sx: SystemStyleObject = {
+ flexDirection: 'column',
+ w: 'full',
+ h: 'full',
+ py: 2,
+ gap: 1,
+ borderBottomRadius: 'base',
+ '&[data-with-footer="true"]': {
+ borderBottomRadius: 0,
+ },
+};
+
const InvocationNode = ({ nodeId, isOpen }: Props) => {
const withFooter = useWithFooter(nodeId);
@@ -28,15 +41,7 @@ const InvocationNode = ({ nodeId, isOpen }: Props) => {
{isOpen && (
<>
-
+
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx
index f0bbab5e12..c3ff0749dd 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeHeader.tsx
@@ -1,3 +1,4 @@
+import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
@@ -13,18 +14,22 @@ type Props = {
isOpen: boolean;
};
+const sx: SystemStyleObject = {
+ borderTopRadius: 'base',
+ alignItems: 'center',
+ justifyContent: 'space-between',
+ h: 8,
+ textAlign: 'center',
+ color: 'base.200',
+ borderBottomRadius: 'base',
+ '&[data-is-open="true"]': {
+ borderBottomRadius: 0,
+ },
+};
+
const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {
return (
-
+
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx
index de327ef133..ef0bb38f07 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldEditModeNodes.tsx
@@ -1,3 +1,4 @@
+import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, FormControl, Spacer } from '@invoke-ai/ui-library';
import { InputFieldDescriptionPopover } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldDescriptionPopover';
import { InputFieldHandle } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle';
@@ -6,6 +7,7 @@ import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
+import type { FieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useRef, useState } from 'react';
import { InputFieldRenderer } from './InputFieldRenderer';
@@ -19,11 +21,72 @@ interface Props {
export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
+ const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
+ const isConnected = useInputFieldIsConnected(nodeId, fieldName);
+
+ if (fieldTemplate.input === 'connection' || isConnected) {
+ return (
+
+ );
+ }
+
+ return (
+
+ );
+});
+
+InputFieldEditModeNodes.displayName = 'InputFieldEditModeNodes';
+
+type CommonProps = {
+ nodeId: string;
+ fieldName: string;
+ isInvalid: boolean;
+ isConnected: boolean;
+ fieldTemplate: FieldInputTemplate;
+};
+
+const ConnectedOrConnectionField = memo(({ nodeId, fieldName, isInvalid, isConnected }: CommonProps) => {
+ return (
+
+
+
+
+
+
+ );
+});
+ConnectedOrConnectionField.displayName = 'ConnectedOrConnectionField';
+
+const directFieldSx: SystemStyleObject = {
+ orientation: 'vertical',
+ px: 2,
+ '&[data-is-dragging="true"]': {
+ opacity: 0.3,
+ },
+ // Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a
+ // connection, the mouse up to end the connection won't fire, leaving the connection in-progress.
+ pointerEvents: 'auto',
+ '&[data-is-connected="true"]': {
+ pointerEvents: 'none',
+ },
+};
+
+const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemplate }: CommonProps) => {
const draggableRef = useRef(null);
const dragHandleRef = useRef(null);
const [isHovered, setIsHovered] = useState(false);
- const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
- const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const onMouseEnter = useCallback(() => {
setIsHovered(true);
@@ -35,30 +98,15 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
const isDragging = useNodeFieldDnd({ nodeId, fieldName }, fieldTemplate, draggableRef, dragHandleRef);
- if (fieldTemplate.input === 'connection' || isConnected) {
- return (
-
-
-
-
-
-
-
- );
- }
-
return (
@@ -79,5 +127,4 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
);
});
-
-InputFieldEditModeNodes.displayName = 'InputFieldEditModeNodes';
+DirectField.displayName = 'DirectField';
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle.tsx
index b0f231e517..0111f4af2d 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle.tsx
@@ -3,13 +3,14 @@ import { Box, Tooltip } from '@invoke-ai/ui-library';
import { Handle, Position } from '@xyflow/react';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import {
- useConnectionValidationResult,
+ useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
+import type { FieldInputTemplate } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -60,33 +61,60 @@ const handleStyles = {
} satisfies CSSProperties;
export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
- const { t } = useTranslation();
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
- const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const isConnectionInProgress = useIsConnectionInProgress();
- const validationResult = useConnectionValidationResult(nodeId, fieldName, 'target');
- const tooltip = useMemo(() => {
- if (isConnectionInProgress && validationResult.messageTKey) {
- return t(validationResult.messageTKey);
- }
- return fieldTypeName;
- }, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
+ if (isConnectionInProgress) {
+ return (
+
+ );
+ }
return (
-
+
+ );
+});
+
+InputFieldHandle.displayName = 'InputFieldHandle';
+
+type HandleCommonProps = {
+ nodeId: string;
+ fieldName: string;
+ fieldTemplate: FieldInputTemplate;
+ fieldTypeName: string;
+ fieldColor: string;
+ isModelField: boolean;
+};
+
+const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
+ return (
+
@@ -94,5 +122,38 @@ export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
);
});
+IdleHandle.displayName = 'IdleHandle';
-InputFieldHandle.displayName = 'InputFieldHandle';
+const ConnectionInProgressHandle = memo(
+ ({ nodeId, fieldName, fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
+ const { t } = useTranslation();
+ const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
+ const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
+
+ const tooltip = useMemo(() => {
+ if (connectionError !== null) {
+ return t(connectionError);
+ }
+ return fieldTypeName;
+ }, [fieldTypeName, t, connectionError]);
+
+ return (
+
+
+
+
+
+ );
+ }
+);
+ConnectionInProgressHandle.displayName = 'ConnectionInProgressHandle';
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldTitle.tsx
index ae9001418d..d1eccd51a7 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldTitle.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldTitle.tsx
@@ -4,7 +4,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useEditable } from 'common/hooks/useEditable';
import { InputFieldTooltipContent } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldTooltipContent';
import {
- useConnectionValidationResult,
+ useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
@@ -47,7 +47,7 @@ export const InputFieldTitle = memo((props: Props) => {
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const isConnectionInProgress = useIsConnectionInProgress();
- const validationResult = useConnectionValidationResult(nodeId, fieldName, 'target');
+ const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
const dispatch = useAppDispatch();
const defaultTitle = useMemo(() => fieldTemplateTitle || t('nodes.unknownField'), [fieldTemplateTitle, t]);
@@ -76,7 +76,7 @@ export const InputFieldTitle = memo((props: Props) => {
noOfLines={1}
data-is-invalid={isInvalid}
data-is-disabled={
- (isConnectionInProgress && !validationResult.isValid && !isConnectionStartField) || isConnected
+ (isConnectionInProgress && connectionError !== null && !isConnectionStartField) || isConnected
}
onDoubleClick={editable.startEditing}
>
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldHandle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldHandle.tsx
index 7c9fd92d55..6f44702ad6 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldHandle.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldHandle.tsx
@@ -3,13 +3,14 @@ import { Box, Tooltip } from '@invoke-ai/ui-library';
import { Handle, Position } from '@xyflow/react';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import {
- useConnectionValidationResult,
+ useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
+import type { FieldOutputTemplate } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -60,33 +61,60 @@ const handleStyles = {
} satisfies CSSProperties;
export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
- const { t } = useTranslation();
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
- const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source');
const isConnectionInProgress = useIsConnectionInProgress();
- const validationResult = useConnectionValidationResult(nodeId, fieldName, 'source');
- const tooltip = useMemo(() => {
- if (isConnectionInProgress && validationResult.messageTKey) {
- return t(validationResult.messageTKey);
- }
- return fieldTypeName;
- }, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
+ if (isConnectionInProgress) {
+ return (
+
+ );
+ }
return (
-
+
+ );
+});
+
+OutputFieldHandle.displayName = 'OutputFieldHandle';
+
+type HandleCommonProps = {
+ nodeId: string;
+ fieldName: string;
+ fieldTemplate: FieldOutputTemplate;
+ fieldTypeName: string;
+ fieldColor: string;
+ isModelField: boolean;
+};
+
+const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
+ return (
+
@@ -94,5 +122,38 @@ export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
);
});
+IdleHandle.displayName = 'IdleHandle';
-OutputFieldHandle.displayName = 'OutputFieldHandle';
+const ConnectionInProgressHandle = memo(
+ ({ nodeId, fieldName, fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
+ const { t } = useTranslation();
+ const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
+ const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'target');
+
+ const tooltip = useMemo(() => {
+ if (connectionErrorTKey !== null) {
+ return t(connectionErrorTKey);
+ }
+ return fieldTypeName;
+ }, [fieldTypeName, t, connectionErrorTKey]);
+
+ return (
+
+
+
+
+
+ );
+ }
+);
+ConnectionInProgressHandle.displayName = 'ConnectionInProgressHandle';
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTitle.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTitle.tsx
index 311a917c2a..22a4bb9f38 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTitle.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTitle.tsx
@@ -2,7 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Text, Tooltip } from '@invoke-ai/ui-library';
import { OutputFieldTooltipContent } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTooltipContent';
import {
- useConnectionValidationResult,
+ useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
@@ -31,7 +31,7 @@ export const OutputFieldTitle = memo(({ nodeId, fieldName }: Props) => {
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source');
const isConnectionInProgress = useIsConnectionInProgress();
- const validationResult = useConnectionValidationResult(nodeId, fieldName, 'source');
+ const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'source');
return (
{
>
diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeCollapseButton.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeCollapseButton.tsx
index 2117843887..b485a928b9 100644
--- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeCollapseButton.tsx
+++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NodeCollapseButton.tsx
@@ -1,3 +1,4 @@
+import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Icon, IconButton } from '@invoke-ai/ui-library';
import { useUpdateNodeInternals } from '@xyflow/react';
import { useAppDispatch } from 'app/store/storeHooks';
@@ -10,6 +11,15 @@ interface Props {
isOpen: boolean;
}
+const iconSx: SystemStyleObject = {
+ transitionProperty: 'transform',
+ transitionDuration: 'normal',
+ transform: 'rotate(180deg)',
+ '&[data-is-open="true"]': {
+ transform: 'rotate(0deg)',
+ },
+};
+
const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
const dispatch = useAppDispatch();
const updateNodeInternals = useUpdateNodeInternals();
@@ -28,14 +38,7 @@ const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
w={8}
h={8}
variant="link"
- icon={
-
- }
+ icon={}
/>
);
};
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldConnectionState.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldConnectionState.ts
index b860e0df22..62dfb723cd 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldConnectionState.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldConnectionState.ts
@@ -10,18 +10,18 @@ import {
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
import { useMemo } from 'react';
-export const useConnectionValidationResult = (nodeId: string, fieldName: string, handleType: HandleType) => {
+export const useConnectionErrorTKey = (nodeId: string, fieldName: string, handleType: HandleType): string | null => {
const pendingConnection = useStore($pendingConnection);
const templates = useStore($templates);
const edgePendingUpdate = useStore($edgePendingUpdate);
- const selectValidationResult = useMemo(
+ const selectConnectionError = useMemo(
() => makeConnectionErrorSelector(templates, nodeId, fieldName, handleType, pendingConnection, edgePendingUpdate),
[templates, nodeId, fieldName, handleType, pendingConnection, edgePendingUpdate]
);
- const validationResult = useAppSelector(selectValidationResult);
- return validationResult;
+ const connectionError = useAppSelector(selectConnectionError);
+ return connectionError;
};
export const useIsConnectionStartField = (nodeId: string, fieldName: string, handleType: HandleType) => {
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
index e4337018bc..a4f47a6ce0 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts
@@ -21,7 +21,7 @@ export const useIsValidConnection = (): IsValidConnection => {
const edgePendingUpdate = $edgePendingUpdate.get();
const { nodes, edges } = selectNodesSlice(store.getState());
- const validationResult = validateConnection(
+ const connectionErrorTKey = validateConnection(
{ source, sourceHandle, target, targetHandle },
nodes,
edges,
@@ -30,7 +30,7 @@ export const useIsValidConnection = (): IsValidConnection => {
shouldValidateGraph
);
- return validationResult.isValid;
+ return connectionErrorTKey === null;
},
[templates, shouldValidateGraph, store]
);
diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeCopyPaste.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeCopyPaste.ts
index 3974857ea1..7fdcef98b6 100644
--- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeCopyPaste.ts
+++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeCopyPaste.ts
@@ -126,7 +126,7 @@ const _pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
assert(!isNil(targetHandle));
// Validate the edge before adding it
- const validationResult = validateConnection(
+ const connectionErrorTKey = validateConnection(
{ source, sourceHandle, target, targetHandle },
validationNodes,
validationEdges,
@@ -135,10 +135,10 @@ const _pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
true
);
// If the edge is invalid, log a warning and skip it
- if (!validationResult.isValid) {
+ if (connectionErrorTKey !== null) {
log.warn(
{ edge: { source, sourceHandle, target, targetHandle } },
- `Invalid edge, cannot paste: ${t(validationResult.messageTKey)}`
+ `Invalid edge, cannot paste: ${t(connectionErrorTKey)}`
);
return;
}
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts
index 2a88e2078a..9789e991a7 100644
--- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts
@@ -105,8 +105,8 @@ export const getTargetCandidateFields = (
const targetCandidateFields = map(targetTemplate.inputs).filter((field) => {
const c = { source, sourceHandle, target, targetHandle: field.name };
- const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
- return r.isValid;
+ const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
+ return connectionErrorTKey === null;
});
return targetCandidateFields;
@@ -141,8 +141,8 @@ export const getSourceCandidateFields = (
const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => {
const c = { source, sourceHandle: field.name, target, targetHandle };
- const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
- return r.isValid;
+ const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
+ return connectionErrorTKey === null;
});
return sourceCandidateFields;
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts
index 00492b994e..fa10fcc3e6 100644
--- a/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/util/makeConnectionErrorSelector.ts
@@ -1,8 +1,8 @@
+import { createSelector } from '@reduxjs/toolkit';
import type { HandleType } from '@xyflow/react';
-import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
-import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
+import { validateConnection } from 'features/nodes/store/util/validateConnection';
import type { AnyEdge } from 'features/nodes/types/invocation';
/**
@@ -25,18 +25,18 @@ export const makeConnectionErrorSelector = (
pendingConnection: PendingConnection | null,
edgePendingUpdate: AnyEdge | null
) => {
- return createMemoizedSelector(selectNodesSlice, (nodesSlice: NodesState) => {
+ return createSelector(selectNodesSlice, (nodesSlice: NodesState): string | null => {
const { nodes, edges } = nodesSlice;
if (!pendingConnection) {
- return buildRejectResult('nodes.noConnectionInProgress');
+ return 'nodes.noConnectionInProgress';
}
if (handleType === pendingConnection.handleType) {
if (handleType === 'source') {
- return buildRejectResult('nodes.cannotConnectOutputToOutput');
+ return 'nodes.cannotConnectOutputToOutput';
}
- return buildRejectResult('nodes.cannotConnectInputToInput');
+ return 'nodes.cannotConnectInputToInput';
}
// we have to figure out which is the target and which is the source
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts
index 19035afd54..c50094be06 100644
--- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts
@@ -3,13 +3,13 @@ import { set } from 'lodash-es';
import { describe, expect, it } from 'vitest';
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
-import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
+import { validateConnection } from './validateConnection';
describe(validateConnection.name, () => {
it('should reject invalid connection to self', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null);
- expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
+ expect(r).toEqual('nodes.cannotConnectToSelf');
});
describe('missing nodes', () => {
@@ -19,12 +19,12 @@ describe(validateConnection.name, () => {
it('should reject missing source node', () => {
const r = validateConnection(c, [n2], [], templates, null);
- expect(r).toEqual(buildRejectResult('nodes.missingNode'));
+ expect(r).toEqual('nodes.missingNode');
});
it('should reject missing target node', () => {
const r = validateConnection(c, [n1], [], templates, null);
- expect(r).toEqual(buildRejectResult('nodes.missingNode'));
+ expect(r).toEqual('nodes.missingNode');
});
});
@@ -36,12 +36,12 @@ describe(validateConnection.name, () => {
it('should reject missing source template', () => {
const r = validateConnection(c, nodes, [], { sub }, null);
- expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
+ expect(r).toEqual('nodes.missingInvocationTemplate');
});
it('should reject missing target template', () => {
const r = validateConnection(c, nodes, [], { add }, null);
- expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
+ expect(r).toEqual('nodes.missingInvocationTemplate');
});
});
@@ -53,13 +53,13 @@ describe(validateConnection.name, () => {
it('should reject missing source field template', () => {
const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, [], templates, null);
- expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
+ expect(r).toEqual('nodes.missingFieldTemplate');
});
it('should reject missing target field template', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' };
const r = validateConnection(c, nodes, [], templates, null);
- expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
+ expect(r).toEqual('nodes.missingFieldTemplate');
});
});
@@ -69,19 +69,19 @@ describe(validateConnection.name, () => {
it('should accept non-duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], templates, null);
- expect(r).toEqual(buildAcceptResult());
+ expect(r).toEqual(null);
});
it('should reject duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, null);
- expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection'));
+ expect(r).toEqual('nodes.cannotDuplicateConnection');
});
it('should accept duplicate connections if the duplicate is an ignored edge', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, e);
- expect(r).toEqual(buildAcceptResult());
+ expect(r).toEqual(null);
});
});
@@ -95,7 +95,7 @@ describe(validateConnection.name, () => {
const n2 = buildNode(addWithDirectAField);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
- expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
+ expect(r).toEqual('nodes.cannotConnectToDirectInput');
});
it('should reject connection to a collect node with mismatched item types', () => {
@@ -107,7 +107,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
- expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'));
+ expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes');
});
it('should accept connection to a collect node with matching item types', () => {
@@ -119,7 +119,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
- expect(r).toEqual(buildAcceptResult());
+ expect(r).toEqual(null);
});
it('should reject connections to target field that is already connected', () => {
@@ -131,7 +131,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
- expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection'));
+ expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection');
});
it('should accept connections to target field that is already connected (ignored edge)', () => {
@@ -143,7 +143,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, e1);
- expect(r).toEqual(buildAcceptResult());
+ expect(r).toEqual(null);
});
it('should reject connections between invalid types', () => {
@@ -152,7 +152,7 @@ describe(validateConnection.name, () => {
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null);
- expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch'));
+ expect(r).toEqual('nodes.fieldTypesMustMatch');
});
it('should reject connections that would create cycles', () => {
@@ -163,14 +163,14 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
- expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
+ expect(r).toEqual('nodes.connectionWouldCreateCycle');
});
describe('non-strict mode', () => {
it('should reject connections from self to self in non-strict mode', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null, false);
- expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
+ expect(r).toEqual('nodes.cannotConnectToSelf');
});
it('should reject connections that create cycles in non-strict mode', () => {
const n1 = buildNode(add);
@@ -180,7 +180,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null, false);
- expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
+ expect(r).toEqual('nodes.connectionWouldCreateCycle');
});
it('should otherwise allow invalid connections in non-strict mode', () => {
const n1 = buildNode(add);
@@ -188,7 +188,7 @@ describe(validateConnection.name, () => {
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null, false);
- expect(r).toEqual(buildAcceptResult());
+ expect(r).toEqual(null);
});
});
});
diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts
index 02a55e3121..03c66ba8d8 100644
--- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts
+++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts
@@ -9,16 +9,6 @@ import type { SetNonNullable } from 'type-fest';
type Connection = SetNonNullable;
-type ValidationResult =
- | {
- isValid: true;
- messageTKey?: string;
- }
- | {
- isValid: false;
- messageTKey: string;
- };
-
type ValidateConnectionFunc = (
connection: Connection,
nodes: AnyNode[],
@@ -26,7 +16,7 @@ type ValidateConnectionFunc = (
templates: Templates,
ignoreEdge: AnyEdge | null,
strict?: boolean
-) => ValidationResult;
+) => string | null;
const getEqualityPredicate =
(c: Connection) =>
@@ -45,12 +35,16 @@ const getTargetEqualityPredicate =
return e.target === c.target && e.targetHandle === c.targetHandle;
};
-export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
-export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
-
-export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
+export const validateConnection: ValidateConnectionFunc = (
+ c,
+ nodes,
+ edges,
+ templates,
+ ignoreEdge,
+ strict = true
+): string | null => {
if (c.source === c.target) {
- return buildRejectResult('nodes.cannotConnectToSelf');
+ return 'nodes.cannotConnectToSelf';
}
if (strict) {
@@ -65,66 +59,66 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
if (filteredEdges.some(getEqualityPredicate(c))) {
// We already have a connection from this source to this target
- return buildRejectResult('nodes.cannotDuplicateConnection');
+ return 'nodes.cannotDuplicateConnection';
}
const sourceNode = nodes.find((n) => n.id === c.source);
if (!sourceNode) {
- return buildRejectResult('nodes.missingNode');
+ return 'nodes.missingNode';
}
const targetNode = nodes.find((n) => n.id === c.target);
if (!targetNode) {
- return buildRejectResult('nodes.missingNode');
+ return 'nodes.missingNode';
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
- return buildRejectResult('nodes.missingInvocationTemplate');
+ return 'nodes.missingInvocationTemplate';
}
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
- return buildRejectResult('nodes.missingInvocationTemplate');
+ return 'nodes.missingInvocationTemplate';
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
- return buildRejectResult('nodes.missingFieldTemplate');
+ return 'nodes.missingFieldTemplate';
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
- return buildRejectResult('nodes.missingFieldTemplate');
+ return 'nodes.missingFieldTemplate';
}
if (targetFieldTemplate.input === 'direct') {
- return buildRejectResult('nodes.cannotConnectToDirectInput');
+ return 'nodes.cannotConnectToDirectInput';
}
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types.
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
- return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
+ return 'nodes.cannotMixAndMatchCollectionItemTypes';
}
}
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
// CollectionItemField inputs can have multiple input connections
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
- return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
+ return 'nodes.inputMayOnlyHaveOneConnection';
}
}
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
- return buildRejectResult('nodes.fieldTypesMustMatch');
+ return 'nodes.fieldTypesMustMatch';
}
}
if (getHasCycles(c.source, c.target, nodes, edges)) {
- return buildRejectResult('nodes.connectionWouldCreateCycle');
+ return 'nodes.connectionWouldCreateCycle';
}
- return buildAcceptResult();
+ return null;
};