perf(ui): optimize field handle/title/etc rendering

This commit is contained in:
psychedelicious
2025-02-16 11:37:26 +10:00
parent 726b4637db
commit cd785ba64b
15 changed files with 330 additions and 154 deletions

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, Grid, GridItem } from '@invoke-ai/ui-library';
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
import { OutputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldGate';
@@ -20,6 +21,18 @@ type Props = {
isOpen: boolean;
};
const sx: SystemStyleObject = {
flexDirection: 'column',
w: 'full',
h: 'full',
py: 2,
gap: 1,
borderBottomRadius: 'base',
'&[data-with-footer="true"]': {
borderBottomRadius: 0,
},
};
const InvocationNode = ({ nodeId, isOpen }: Props) => {
const withFooter = useWithFooter(nodeId);
@@ -28,15 +41,7 @@ const InvocationNode = ({ nodeId, isOpen }: Props) => {
<InvocationNodeHeader nodeId={nodeId} isOpen={isOpen} />
{isOpen && (
<>
<Flex
layerStyle="nodeBody"
flexDirection="column"
w="full"
h="full"
py={2}
gap={1}
borderBottomRadius={withFooter ? 0 : 'base'}
>
<Flex layerStyle="nodeBody" sx={sx} data-with-footer={withFooter}>
<Flex flexDir="column" px={2} w="full" h="full">
<Grid gridTemplateColumns="1fr auto" gridAutoRows="1fr">
<ConnectionFields nodeId={nodeId} />

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex } from '@invoke-ai/ui-library';
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
@@ -13,18 +14,22 @@ type Props = {
isOpen: boolean;
};
const sx: SystemStyleObject = {
borderTopRadius: 'base',
alignItems: 'center',
justifyContent: 'space-between',
h: 8,
textAlign: 'center',
color: 'base.200',
borderBottomRadius: 'base',
'&[data-is-open="true"]': {
borderBottomRadius: 0,
},
};
const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {
return (
<Flex
layerStyle="nodeHeader"
borderTopRadius="base"
borderBottomRadius={isOpen ? 0 : 'base'}
alignItems="center"
justifyContent="space-between"
h={8}
textAlign="center"
color="base.200"
>
<Flex layerStyle="nodeHeader" sx={sx} data-is-open={isOpen}>
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
<InvocationNodeClassificationIcon nodeId={nodeId} />
<NodeTitle nodeId={nodeId} />

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Flex, FormControl, Spacer } from '@invoke-ai/ui-library';
import { InputFieldDescriptionPopover } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldDescriptionPopover';
import { InputFieldHandle } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle';
@@ -6,6 +7,7 @@ import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import { memo, useCallback, useRef, useState } from 'react';
import { InputFieldRenderer } from './InputFieldRenderer';
@@ -19,11 +21,72 @@ interface Props {
export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
if (fieldTemplate.input === 'connection' || isConnected) {
return (
<ConnectedOrConnectionField
nodeId={nodeId}
fieldName={fieldName}
isInvalid={isInvalid}
isConnected={isConnected}
fieldTemplate={fieldTemplate}
/>
);
}
return (
<DirectField
nodeId={nodeId}
fieldName={fieldName}
isInvalid={isInvalid}
isConnected={isConnected}
fieldTemplate={fieldTemplate}
/>
);
});
InputFieldEditModeNodes.displayName = 'InputFieldEditModeNodes';
type CommonProps = {
nodeId: string;
fieldName: string;
isInvalid: boolean;
isConnected: boolean;
fieldTemplate: FieldInputTemplate;
};
const ConnectedOrConnectionField = memo(({ nodeId, fieldName, isInvalid, isConnected }: CommonProps) => {
return (
<InputFieldWrapper>
<FormControl isInvalid={isInvalid} isDisabled={isConnected} px={2}>
<InputFieldTitle nodeId={nodeId} fieldName={fieldName} isInvalid={isInvalid} />
</FormControl>
<InputFieldHandle nodeId={nodeId} fieldName={fieldName} />
</InputFieldWrapper>
);
});
ConnectedOrConnectionField.displayName = 'ConnectedOrConnectionField';
const directFieldSx: SystemStyleObject = {
orientation: 'vertical',
px: 2,
'&[data-is-dragging="true"]': {
opacity: 0.3,
},
// Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a
// connection, the mouse up to end the connection won't fire, leaving the connection in-progress.
pointerEvents: 'auto',
'&[data-is-connected="true"]': {
pointerEvents: 'none',
},
};
const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemplate }: CommonProps) => {
const draggableRef = useRef<HTMLDivElement>(null);
const dragHandleRef = useRef<HTMLDivElement>(null);
const [isHovered, setIsHovered] = useState(false);
const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const onMouseEnter = useCallback(() => {
setIsHovered(true);
@@ -35,30 +98,15 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
const isDragging = useNodeFieldDnd({ nodeId, fieldName }, fieldTemplate, draggableRef, dragHandleRef);
if (fieldTemplate.input === 'connection' || isConnected) {
return (
<InputFieldWrapper>
<FormControl isInvalid={isInvalid} isDisabled={isConnected} px={2}>
<InputFieldTitle nodeId={nodeId} fieldName={fieldName} isInvalid={isInvalid} />
</FormControl>
<InputFieldHandle nodeId={nodeId} fieldName={fieldName} />
</InputFieldWrapper>
);
}
return (
<InputFieldWrapper>
<FormControl
ref={draggableRef}
isInvalid={isInvalid}
isDisabled={isConnected}
// Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a
// connection, the mouse up to end the connection won't fire, leaving the connection in-progress.
pointerEvents={isConnected ? 'none' : 'auto'}
orientation="vertical"
px={2}
opacity={isDragging ? 0.3 : 1}
sx={directFieldSx}
data-is-connected={isConnected}
data-is-dragging={isDragging}
>
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
<Flex className="nodrag" ref={dragHandleRef} gap={1}>
@@ -79,5 +127,4 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
</InputFieldWrapper>
);
});
InputFieldEditModeNodes.displayName = 'InputFieldEditModeNodes';
DirectField.displayName = 'DirectField';

View File

@@ -3,13 +3,14 @@ import { Box, Tooltip } from '@invoke-ai/ui-library';
import { Handle, Position } from '@xyflow/react';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import {
useConnectionValidationResult,
useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import type { FieldInputTemplate } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -60,33 +61,60 @@ const handleStyles = {
} satisfies CSSProperties;
export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const isConnectionInProgress = useIsConnectionInProgress();
const validationResult = useConnectionValidationResult(nodeId, fieldName, 'target');
const tooltip = useMemo(() => {
if (isConnectionInProgress && validationResult.messageTKey) {
return t(validationResult.messageTKey);
}
return fieldTypeName;
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
if (isConnectionInProgress) {
return (
<ConnectionInProgressHandle
nodeId={nodeId}
fieldName={fieldName}
fieldTemplate={fieldTemplate}
fieldTypeName={fieldTypeName}
fieldColor={fieldColor}
isModelField={isModelField}
/>
);
}
return (
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<IdleHandle
nodeId={nodeId}
fieldName={fieldName}
fieldTemplate={fieldTemplate}
fieldTypeName={fieldTypeName}
fieldColor={fieldColor}
isModelField={isModelField}
/>
);
});
InputFieldHandle.displayName = 'InputFieldHandle';
type HandleCommonProps = {
nodeId: string;
fieldName: string;
fieldTemplate: FieldInputTemplate;
fieldTypeName: string;
fieldColor: string;
isModelField: boolean;
};
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
return (
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}
data-is-batch-field={fieldTemplate.type.batch}
data-is-model-field={isModelField}
data-is-connection-in-progress={isConnectionInProgress}
data-is-connection-start-field={isConnectionStartField}
data-is-connection-valid={validationResult.isValid}
data-is-connection-in-progress={false}
data-is-connection-start-field={false}
data-is-connection-valid={false}
backgroundColor={fieldTemplate.type.cardinality === 'SINGLE' ? fieldColor : 'base.900'}
borderColor={fieldColor}
/>
@@ -94,5 +122,38 @@ export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
</Tooltip>
);
});
IdleHandle.displayName = 'IdleHandle';
InputFieldHandle.displayName = 'InputFieldHandle';
const ConnectionInProgressHandle = memo(
({ nodeId, fieldName, fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
const { t } = useTranslation();
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
const tooltip = useMemo(() => {
if (connectionError !== null) {
return t(connectionError);
}
return fieldTypeName;
}, [fieldTypeName, t, connectionError]);
return (
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}
data-is-batch-field={fieldTemplate.type.batch}
data-is-model-field={isModelField}
data-is-connection-in-progress={true}
data-is-connection-start-field={isConnectionStartField}
data-is-connection-valid={connectionError === null}
backgroundColor={fieldTemplate.type.cardinality === 'SINGLE' ? fieldColor : 'base.900'}
borderColor={fieldColor}
/>
</Handle>
</Tooltip>
);
}
);
ConnectionInProgressHandle.displayName = 'ConnectionInProgressHandle';

View File

@@ -4,7 +4,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
import { useEditable } from 'common/hooks/useEditable';
import { InputFieldTooltipContent } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldTooltipContent';
import {
useConnectionValidationResult,
useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
@@ -47,7 +47,7 @@ export const InputFieldTitle = memo((props: Props) => {
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const isConnectionInProgress = useIsConnectionInProgress();
const validationResult = useConnectionValidationResult(nodeId, fieldName, 'target');
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
const dispatch = useAppDispatch();
const defaultTitle = useMemo(() => fieldTemplateTitle || t('nodes.unknownField'), [fieldTemplateTitle, t]);
@@ -76,7 +76,7 @@ export const InputFieldTitle = memo((props: Props) => {
noOfLines={1}
data-is-invalid={isInvalid}
data-is-disabled={
(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField) || isConnected
(isConnectionInProgress && connectionError !== null && !isConnectionStartField) || isConnected
}
onDoubleClick={editable.startEditing}
>

View File

@@ -3,13 +3,14 @@ import { Box, Tooltip } from '@invoke-ai/ui-library';
import { Handle, Position } from '@xyflow/react';
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
import {
useConnectionValidationResult,
useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
import type { FieldOutputTemplate } from 'features/nodes/types/field';
import type { CSSProperties } from 'react';
import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
@@ -60,33 +61,60 @@ const handleStyles = {
} satisfies CSSProperties;
export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
const { t } = useTranslation();
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source');
const isConnectionInProgress = useIsConnectionInProgress();
const validationResult = useConnectionValidationResult(nodeId, fieldName, 'source');
const tooltip = useMemo(() => {
if (isConnectionInProgress && validationResult.messageTKey) {
return t(validationResult.messageTKey);
}
return fieldTypeName;
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
if (isConnectionInProgress) {
return (
<ConnectionInProgressHandle
nodeId={nodeId}
fieldName={fieldName}
fieldTemplate={fieldTemplate}
fieldTypeName={fieldTypeName}
fieldColor={fieldColor}
isModelField={isModelField}
/>
);
}
return (
<Tooltip label={tooltip} placement="end" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<IdleHandle
nodeId={nodeId}
fieldName={fieldName}
fieldTemplate={fieldTemplate}
fieldTypeName={fieldTypeName}
fieldColor={fieldColor}
isModelField={isModelField}
/>
);
});
OutputFieldHandle.displayName = 'OutputFieldHandle';
type HandleCommonProps = {
nodeId: string;
fieldName: string;
fieldTemplate: FieldOutputTemplate;
fieldTypeName: string;
fieldColor: string;
isModelField: boolean;
};
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
return (
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}
data-is-batch-field={fieldTemplate.type.batch}
data-is-model-field={isModelField}
data-is-connection-in-progress={isConnectionInProgress}
data-is-connection-start-field={isConnectionStartField}
data-is-connection-valid={validationResult.isValid}
data-is-connection-in-progress={false}
data-is-connection-start-field={false}
data-is-connection-valid={false}
backgroundColor={fieldTemplate.type.cardinality === 'SINGLE' ? fieldColor : 'base.900'}
borderColor={fieldColor}
/>
@@ -94,5 +122,38 @@ export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
</Tooltip>
);
});
IdleHandle.displayName = 'IdleHandle';
OutputFieldHandle.displayName = 'OutputFieldHandle';
const ConnectionInProgressHandle = memo(
({ nodeId, fieldName, fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
const { t } = useTranslation();
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'target');
const tooltip = useMemo(() => {
if (connectionErrorTKey !== null) {
return t(connectionErrorTKey);
}
return fieldTypeName;
}, [fieldTypeName, t, connectionErrorTKey]);
return (
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
<Box
sx={sx}
data-cardinality={fieldTemplate.type.cardinality}
data-is-batch-field={fieldTemplate.type.batch}
data-is-model-field={isModelField}
data-is-connection-in-progress={true}
data-is-connection-start-field={isConnectionStartField}
data-is-connection-valid={connectionErrorTKey === null}
backgroundColor={fieldTemplate.type.cardinality === 'SINGLE' ? fieldColor : 'base.900'}
borderColor={fieldColor}
/>
</Handle>
</Tooltip>
);
}
);
ConnectionInProgressHandle.displayName = 'ConnectionInProgressHandle';

View File

@@ -2,7 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Text, Tooltip } from '@invoke-ai/ui-library';
import { OutputFieldTooltipContent } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTooltipContent';
import {
useConnectionValidationResult,
useConnectionErrorTKey,
useIsConnectionInProgress,
useIsConnectionStartField,
} from 'features/nodes/hooks/useFieldConnectionState';
@@ -31,7 +31,7 @@ export const OutputFieldTitle = memo(({ nodeId, fieldName }: Props) => {
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source');
const isConnectionInProgress = useIsConnectionInProgress();
const validationResult = useConnectionValidationResult(nodeId, fieldName, 'source');
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'source');
return (
<Tooltip
@@ -41,7 +41,7 @@ export const OutputFieldTitle = memo(({ nodeId, fieldName }: Props) => {
>
<Text
data-is-disabled={
(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField) || isConnected
(isConnectionInProgress && connectionErrorTKey !== null && !isConnectionStartField) || isConnected
}
sx={sx}
>

View File

@@ -1,3 +1,4 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Icon, IconButton } from '@invoke-ai/ui-library';
import { useUpdateNodeInternals } from '@xyflow/react';
import { useAppDispatch } from 'app/store/storeHooks';
@@ -10,6 +11,15 @@ interface Props {
isOpen: boolean;
}
const iconSx: SystemStyleObject = {
transitionProperty: 'transform',
transitionDuration: 'normal',
transform: 'rotate(180deg)',
'&[data-is-open="true"]': {
transform: 'rotate(0deg)',
},
};
const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
const dispatch = useAppDispatch();
const updateNodeInternals = useUpdateNodeInternals();
@@ -28,14 +38,7 @@ const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
w={8}
h={8}
variant="link"
icon={
<Icon
as={PiCaretUpBold}
transform={isOpen ? 'rotate(0deg)' : 'rotate(180deg)'}
transitionProperty="transform"
transitionDuration="normal"
/>
}
icon={<Icon as={PiCaretUpBold} sx={iconSx} data-is-open={isOpen} />}
/>
);
};

View File

@@ -10,18 +10,18 @@ import {
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
import { useMemo } from 'react';
export const useConnectionValidationResult = (nodeId: string, fieldName: string, handleType: HandleType) => {
export const useConnectionErrorTKey = (nodeId: string, fieldName: string, handleType: HandleType): string | null => {
const pendingConnection = useStore($pendingConnection);
const templates = useStore($templates);
const edgePendingUpdate = useStore($edgePendingUpdate);
const selectValidationResult = useMemo(
const selectConnectionError = useMemo(
() => makeConnectionErrorSelector(templates, nodeId, fieldName, handleType, pendingConnection, edgePendingUpdate),
[templates, nodeId, fieldName, handleType, pendingConnection, edgePendingUpdate]
);
const validationResult = useAppSelector(selectValidationResult);
return validationResult;
const connectionError = useAppSelector(selectConnectionError);
return connectionError;
};
export const useIsConnectionStartField = (nodeId: string, fieldName: string, handleType: HandleType) => {

View File

@@ -21,7 +21,7 @@ export const useIsValidConnection = (): IsValidConnection<AnyEdge> => {
const edgePendingUpdate = $edgePendingUpdate.get();
const { nodes, edges } = selectNodesSlice(store.getState());
const validationResult = validateConnection(
const connectionErrorTKey = validateConnection(
{ source, sourceHandle, target, targetHandle },
nodes,
edges,
@@ -30,7 +30,7 @@ export const useIsValidConnection = (): IsValidConnection<AnyEdge> => {
shouldValidateGraph
);
return validationResult.isValid;
return connectionErrorTKey === null;
},
[templates, shouldValidateGraph, store]
);

View File

@@ -126,7 +126,7 @@ const _pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
assert(!isNil(targetHandle));
// Validate the edge before adding it
const validationResult = validateConnection(
const connectionErrorTKey = validateConnection(
{ source, sourceHandle, target, targetHandle },
validationNodes,
validationEdges,
@@ -135,10 +135,10 @@ const _pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
true
);
// If the edge is invalid, log a warning and skip it
if (!validationResult.isValid) {
if (connectionErrorTKey !== null) {
log.warn(
{ edge: { source, sourceHandle, target, targetHandle } },
`Invalid edge, cannot paste: ${t(validationResult.messageTKey)}`
`Invalid edge, cannot paste: ${t(connectionErrorTKey)}`
);
return;
}

View File

@@ -105,8 +105,8 @@ export const getTargetCandidateFields = (
const targetCandidateFields = map(targetTemplate.inputs).filter((field) => {
const c = { source, sourceHandle, target, targetHandle: field.name };
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
return r.isValid;
const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
return connectionErrorTKey === null;
});
return targetCandidateFields;
@@ -141,8 +141,8 @@ export const getSourceCandidateFields = (
const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => {
const c = { source, sourceHandle: field.name, target, targetHandle };
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
return r.isValid;
const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
return connectionErrorTKey === null;
});
return sourceCandidateFields;

View File

@@ -1,8 +1,8 @@
import { createSelector } from '@reduxjs/toolkit';
import type { HandleType } from '@xyflow/react';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import type { AnyEdge } from 'features/nodes/types/invocation';
/**
@@ -25,18 +25,18 @@ export const makeConnectionErrorSelector = (
pendingConnection: PendingConnection | null,
edgePendingUpdate: AnyEdge | null
) => {
return createMemoizedSelector(selectNodesSlice, (nodesSlice: NodesState) => {
return createSelector(selectNodesSlice, (nodesSlice: NodesState): string | null => {
const { nodes, edges } = nodesSlice;
if (!pendingConnection) {
return buildRejectResult('nodes.noConnectionInProgress');
return 'nodes.noConnectionInProgress';
}
if (handleType === pendingConnection.handleType) {
if (handleType === 'source') {
return buildRejectResult('nodes.cannotConnectOutputToOutput');
return 'nodes.cannotConnectOutputToOutput';
}
return buildRejectResult('nodes.cannotConnectInputToInput');
return 'nodes.cannotConnectInputToInput';
}
// we have to figure out which is the target and which is the source

View File

@@ -3,13 +3,13 @@ import { set } from 'lodash-es';
import { describe, expect, it } from 'vitest';
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
import { validateConnection } from './validateConnection';
describe(validateConnection.name, () => {
it('should reject invalid connection to self', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
expect(r).toEqual('nodes.cannotConnectToSelf');
});
describe('missing nodes', () => {
@@ -19,12 +19,12 @@ describe(validateConnection.name, () => {
it('should reject missing source node', () => {
const r = validateConnection(c, [n2], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
expect(r).toEqual('nodes.missingNode');
});
it('should reject missing target node', () => {
const r = validateConnection(c, [n1], [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
expect(r).toEqual('nodes.missingNode');
});
});
@@ -36,12 +36,12 @@ describe(validateConnection.name, () => {
it('should reject missing source template', () => {
const r = validateConnection(c, nodes, [], { sub }, null);
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
expect(r).toEqual('nodes.missingInvocationTemplate');
});
it('should reject missing target template', () => {
const r = validateConnection(c, nodes, [], { add }, null);
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
expect(r).toEqual('nodes.missingInvocationTemplate');
});
});
@@ -53,13 +53,13 @@ describe(validateConnection.name, () => {
it('should reject missing source field template', () => {
const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
expect(r).toEqual('nodes.missingFieldTemplate');
});
it('should reject missing target field template', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
expect(r).toEqual('nodes.missingFieldTemplate');
});
});
@@ -69,19 +69,19 @@ describe(validateConnection.name, () => {
it('should accept non-duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], templates, null);
expect(r).toEqual(buildAcceptResult());
expect(r).toEqual(null);
});
it('should reject duplicate connections', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection'));
expect(r).toEqual('nodes.cannotDuplicateConnection');
});
it('should accept duplicate connections if the duplicate is an ignored edge', () => {
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const e = buildEdge(n1.id, 'value', n2.id, 'a');
const r = validateConnection(c, [n1, n2], [e], templates, e);
expect(r).toEqual(buildAcceptResult());
expect(r).toEqual(null);
});
});
@@ -95,7 +95,7 @@ describe(validateConnection.name, () => {
const n2 = buildNode(addWithDirectAField);
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
expect(r).toEqual('nodes.cannotConnectToDirectInput');
});
it('should reject connection to a collect node with mismatched item types', () => {
@@ -107,7 +107,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'));
expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes');
});
it('should accept connection to a collect node with matching item types', () => {
@@ -119,7 +119,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildAcceptResult());
expect(r).toEqual(null);
});
it('should reject connections to target field that is already connected', () => {
@@ -131,7 +131,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection'));
expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection');
});
it('should accept connections to target field that is already connected (ignored edge)', () => {
@@ -143,7 +143,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, e1);
expect(r).toEqual(buildAcceptResult());
expect(r).toEqual(null);
});
it('should reject connections between invalid types', () => {
@@ -152,7 +152,7 @@ describe(validateConnection.name, () => {
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null);
expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch'));
expect(r).toEqual('nodes.fieldTypesMustMatch');
});
it('should reject connections that would create cycles', () => {
@@ -163,14 +163,14 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null);
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
expect(r).toEqual('nodes.connectionWouldCreateCycle');
});
describe('non-strict mode', () => {
it('should reject connections from self to self in non-strict mode', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
const r = validateConnection(c, [], [], templates, null, false);
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
expect(r).toEqual('nodes.cannotConnectToSelf');
});
it('should reject connections that create cycles in non-strict mode', () => {
const n1 = buildNode(add);
@@ -180,7 +180,7 @@ describe(validateConnection.name, () => {
const edges = [e1];
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
const r = validateConnection(c, nodes, edges, templates, null, false);
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
expect(r).toEqual('nodes.connectionWouldCreateCycle');
});
it('should otherwise allow invalid connections in non-strict mode', () => {
const n1 = buildNode(add);
@@ -188,7 +188,7 @@ describe(validateConnection.name, () => {
const nodes = [n1, n2];
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
const r = validateConnection(c, nodes, [], templates, null, false);
expect(r).toEqual(buildAcceptResult());
expect(r).toEqual(null);
});
});
});

View File

@@ -9,16 +9,6 @@ import type { SetNonNullable } from 'type-fest';
type Connection = SetNonNullable<NullableConnection>;
type ValidationResult =
| {
isValid: true;
messageTKey?: string;
}
| {
isValid: false;
messageTKey: string;
};
type ValidateConnectionFunc = (
connection: Connection,
nodes: AnyNode[],
@@ -26,7 +16,7 @@ type ValidateConnectionFunc = (
templates: Templates,
ignoreEdge: AnyEdge | null,
strict?: boolean
) => ValidationResult;
) => string | null;
const getEqualityPredicate =
(c: Connection) =>
@@ -45,12 +35,16 @@ const getTargetEqualityPredicate =
return e.target === c.target && e.targetHandle === c.targetHandle;
};
export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
export const validateConnection: ValidateConnectionFunc = (
c,
nodes,
edges,
templates,
ignoreEdge,
strict = true
): string | null => {
if (c.source === c.target) {
return buildRejectResult('nodes.cannotConnectToSelf');
return 'nodes.cannotConnectToSelf';
}
if (strict) {
@@ -65,66 +59,66 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
if (filteredEdges.some(getEqualityPredicate(c))) {
// We already have a connection from this source to this target
return buildRejectResult('nodes.cannotDuplicateConnection');
return 'nodes.cannotDuplicateConnection';
}
const sourceNode = nodes.find((n) => n.id === c.source);
if (!sourceNode) {
return buildRejectResult('nodes.missingNode');
return 'nodes.missingNode';
}
const targetNode = nodes.find((n) => n.id === c.target);
if (!targetNode) {
return buildRejectResult('nodes.missingNode');
return 'nodes.missingNode';
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
return 'nodes.missingInvocationTemplate';
}
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return buildRejectResult('nodes.missingInvocationTemplate');
return 'nodes.missingInvocationTemplate';
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
return 'nodes.missingFieldTemplate';
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return buildRejectResult('nodes.missingFieldTemplate');
return 'nodes.missingFieldTemplate';
}
if (targetFieldTemplate.input === 'direct') {
return buildRejectResult('nodes.cannotConnectToDirectInput');
return 'nodes.cannotConnectToDirectInput';
}
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types.
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
return 'nodes.cannotMixAndMatchCollectionItemTypes';
}
}
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
// CollectionItemField inputs can have multiple input connections
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
return 'nodes.inputMayOnlyHaveOneConnection';
}
}
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
return buildRejectResult('nodes.fieldTypesMustMatch');
return 'nodes.fieldTypesMustMatch';
}
}
if (getHasCycles(c.source, c.target, nodes, edges)) {
return buildRejectResult('nodes.connectionWouldCreateCycle');
return 'nodes.connectionWouldCreateCycle';
}
return buildAcceptResult();
return null;
};