mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-01 15:54:58 -05:00
perf(ui): optimize field handle/title/etc rendering
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, Grid, GridItem } from '@invoke-ai/ui-library';
|
||||
import { InputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldGate';
|
||||
import { OutputFieldGate } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldGate';
|
||||
@@ -20,6 +21,18 @@ type Props = {
|
||||
isOpen: boolean;
|
||||
};
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
flexDirection: 'column',
|
||||
w: 'full',
|
||||
h: 'full',
|
||||
py: 2,
|
||||
gap: 1,
|
||||
borderBottomRadius: 'base',
|
||||
'&[data-with-footer="true"]': {
|
||||
borderBottomRadius: 0,
|
||||
},
|
||||
};
|
||||
|
||||
const InvocationNode = ({ nodeId, isOpen }: Props) => {
|
||||
const withFooter = useWithFooter(nodeId);
|
||||
|
||||
@@ -28,15 +41,7 @@ const InvocationNode = ({ nodeId, isOpen }: Props) => {
|
||||
<InvocationNodeHeader nodeId={nodeId} isOpen={isOpen} />
|
||||
{isOpen && (
|
||||
<>
|
||||
<Flex
|
||||
layerStyle="nodeBody"
|
||||
flexDirection="column"
|
||||
w="full"
|
||||
h="full"
|
||||
py={2}
|
||||
gap={1}
|
||||
borderBottomRadius={withFooter ? 0 : 'base'}
|
||||
>
|
||||
<Flex layerStyle="nodeBody" sx={sx} data-with-footer={withFooter}>
|
||||
<Flex flexDir="column" px={2} w="full" h="full">
|
||||
<Grid gridTemplateColumns="1fr auto" gridAutoRows="1fr">
|
||||
<ConnectionFields nodeId={nodeId} />
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex } from '@invoke-ai/ui-library';
|
||||
import NodeCollapseButton from 'features/nodes/components/flow/nodes/common/NodeCollapseButton';
|
||||
import NodeTitle from 'features/nodes/components/flow/nodes/common/NodeTitle';
|
||||
@@ -13,18 +14,22 @@ type Props = {
|
||||
isOpen: boolean;
|
||||
};
|
||||
|
||||
const sx: SystemStyleObject = {
|
||||
borderTopRadius: 'base',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'space-between',
|
||||
h: 8,
|
||||
textAlign: 'center',
|
||||
color: 'base.200',
|
||||
borderBottomRadius: 'base',
|
||||
'&[data-is-open="true"]': {
|
||||
borderBottomRadius: 0,
|
||||
},
|
||||
};
|
||||
|
||||
const InvocationNodeHeader = ({ nodeId, isOpen }: Props) => {
|
||||
return (
|
||||
<Flex
|
||||
layerStyle="nodeHeader"
|
||||
borderTopRadius="base"
|
||||
borderBottomRadius={isOpen ? 0 : 'base'}
|
||||
alignItems="center"
|
||||
justifyContent="space-between"
|
||||
h={8}
|
||||
textAlign="center"
|
||||
color="base.200"
|
||||
>
|
||||
<Flex layerStyle="nodeHeader" sx={sx} data-is-open={isOpen}>
|
||||
<NodeCollapseButton nodeId={nodeId} isOpen={isOpen} />
|
||||
<InvocationNodeClassificationIcon nodeId={nodeId} />
|
||||
<NodeTitle nodeId={nodeId} />
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Flex, FormControl, Spacer } from '@invoke-ai/ui-library';
|
||||
import { InputFieldDescriptionPopover } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldDescriptionPopover';
|
||||
import { InputFieldHandle } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldHandle';
|
||||
@@ -6,6 +7,7 @@ import { useNodeFieldDnd } from 'features/nodes/components/sidePanel/builder/dnd
|
||||
import { useInputFieldIsConnected } from 'features/nodes/hooks/useInputFieldIsConnected';
|
||||
import { useInputFieldIsInvalid } from 'features/nodes/hooks/useInputFieldIsInvalid';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { memo, useCallback, useRef, useState } from 'react';
|
||||
|
||||
import { InputFieldRenderer } from './InputFieldRenderer';
|
||||
@@ -19,11 +21,72 @@ interface Props {
|
||||
|
||||
export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
|
||||
if (fieldTemplate.input === 'connection' || isConnected) {
|
||||
return (
|
||||
<ConnectedOrConnectionField
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
isInvalid={isInvalid}
|
||||
isConnected={isConnected}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<DirectField
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
isInvalid={isInvalid}
|
||||
isConnected={isConnected}
|
||||
fieldTemplate={fieldTemplate}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
InputFieldEditModeNodes.displayName = 'InputFieldEditModeNodes';
|
||||
|
||||
type CommonProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
isInvalid: boolean;
|
||||
isConnected: boolean;
|
||||
fieldTemplate: FieldInputTemplate;
|
||||
};
|
||||
|
||||
const ConnectedOrConnectionField = memo(({ nodeId, fieldName, isInvalid, isConnected }: CommonProps) => {
|
||||
return (
|
||||
<InputFieldWrapper>
|
||||
<FormControl isInvalid={isInvalid} isDisabled={isConnected} px={2}>
|
||||
<InputFieldTitle nodeId={nodeId} fieldName={fieldName} isInvalid={isInvalid} />
|
||||
</FormControl>
|
||||
<InputFieldHandle nodeId={nodeId} fieldName={fieldName} />
|
||||
</InputFieldWrapper>
|
||||
);
|
||||
});
|
||||
ConnectedOrConnectionField.displayName = 'ConnectedOrConnectionField';
|
||||
|
||||
const directFieldSx: SystemStyleObject = {
|
||||
orientation: 'vertical',
|
||||
px: 2,
|
||||
'&[data-is-dragging="true"]': {
|
||||
opacity: 0.3,
|
||||
},
|
||||
// Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a
|
||||
// connection, the mouse up to end the connection won't fire, leaving the connection in-progress.
|
||||
pointerEvents: 'auto',
|
||||
'&[data-is-connected="true"]': {
|
||||
pointerEvents: 'none',
|
||||
},
|
||||
};
|
||||
|
||||
const DirectField = memo(({ nodeId, fieldName, isInvalid, isConnected, fieldTemplate }: CommonProps) => {
|
||||
const draggableRef = useRef<HTMLDivElement>(null);
|
||||
const dragHandleRef = useRef<HTMLDivElement>(null);
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const isInvalid = useInputFieldIsInvalid(nodeId, fieldName);
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
|
||||
const onMouseEnter = useCallback(() => {
|
||||
setIsHovered(true);
|
||||
@@ -35,30 +98,15 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
|
||||
|
||||
const isDragging = useNodeFieldDnd({ nodeId, fieldName }, fieldTemplate, draggableRef, dragHandleRef);
|
||||
|
||||
if (fieldTemplate.input === 'connection' || isConnected) {
|
||||
return (
|
||||
<InputFieldWrapper>
|
||||
<FormControl isInvalid={isInvalid} isDisabled={isConnected} px={2}>
|
||||
<InputFieldTitle nodeId={nodeId} fieldName={fieldName} isInvalid={isInvalid} />
|
||||
</FormControl>
|
||||
|
||||
<InputFieldHandle nodeId={nodeId} fieldName={fieldName} />
|
||||
</InputFieldWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<InputFieldWrapper>
|
||||
<FormControl
|
||||
ref={draggableRef}
|
||||
isInvalid={isInvalid}
|
||||
isDisabled={isConnected}
|
||||
// Without pointerEvents prop, disabled inputs don't trigger reactflow events. For example, when making a
|
||||
// connection, the mouse up to end the connection won't fire, leaving the connection in-progress.
|
||||
pointerEvents={isConnected ? 'none' : 'auto'}
|
||||
orientation="vertical"
|
||||
px={2}
|
||||
opacity={isDragging ? 0.3 : 1}
|
||||
sx={directFieldSx}
|
||||
data-is-connected={isConnected}
|
||||
data-is-dragging={isDragging}
|
||||
>
|
||||
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
<Flex className="nodrag" ref={dragHandleRef} gap={1}>
|
||||
@@ -79,5 +127,4 @@ export const InputFieldEditModeNodes = memo(({ nodeId, fieldName }: Props) => {
|
||||
</InputFieldWrapper>
|
||||
);
|
||||
});
|
||||
|
||||
InputFieldEditModeNodes.displayName = 'InputFieldEditModeNodes';
|
||||
DirectField.displayName = 'DirectField';
|
||||
|
||||
@@ -3,13 +3,14 @@ import { Box, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { Handle, Position } from '@xyflow/react';
|
||||
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||
import {
|
||||
useConnectionValidationResult,
|
||||
useConnectionErrorTKey,
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useInputFieldTemplate } from 'features/nodes/hooks/useInputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
|
||||
import type { FieldInputTemplate } from 'features/nodes/types/field';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -60,33 +61,60 @@ const handleStyles = {
|
||||
} satisfies CSSProperties;
|
||||
|
||||
export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const fieldTemplate = useInputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
|
||||
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
|
||||
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const isConnectionInProgress = useIsConnectionInProgress();
|
||||
const validationResult = useConnectionValidationResult(nodeId, fieldName, 'target');
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (isConnectionInProgress && validationResult.messageTKey) {
|
||||
return t(validationResult.messageTKey);
|
||||
}
|
||||
return fieldTypeName;
|
||||
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
|
||||
if (isConnectionInProgress) {
|
||||
return (
|
||||
<ConnectionInProgressHandle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
fieldTemplate={fieldTemplate}
|
||||
fieldTypeName={fieldTypeName}
|
||||
fieldColor={fieldColor}
|
||||
isModelField={isModelField}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<IdleHandle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
fieldTemplate={fieldTemplate}
|
||||
fieldTypeName={fieldTypeName}
|
||||
fieldColor={fieldColor}
|
||||
isModelField={isModelField}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
InputFieldHandle.displayName = 'InputFieldHandle';
|
||||
|
||||
type HandleCommonProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
fieldTemplate: FieldInputTemplate;
|
||||
fieldTypeName: string;
|
||||
fieldColor: string;
|
||||
isModelField: boolean;
|
||||
};
|
||||
|
||||
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
return (
|
||||
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
data-is-batch-field={fieldTemplate.type.batch}
|
||||
data-is-model-field={isModelField}
|
||||
data-is-connection-in-progress={isConnectionInProgress}
|
||||
data-is-connection-start-field={isConnectionStartField}
|
||||
data-is-connection-valid={validationResult.isValid}
|
||||
data-is-connection-in-progress={false}
|
||||
data-is-connection-start-field={false}
|
||||
data-is-connection-valid={false}
|
||||
backgroundColor={fieldTemplate.type.cardinality === 'SINGLE' ? fieldColor : 'base.900'}
|
||||
borderColor={fieldColor}
|
||||
/>
|
||||
@@ -94,5 +122,38 @@ export const InputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
IdleHandle.displayName = 'IdleHandle';
|
||||
|
||||
InputFieldHandle.displayName = 'InputFieldHandle';
|
||||
const ConnectionInProgressHandle = memo(
|
||||
({ nodeId, fieldName, fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
const { t } = useTranslation();
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (connectionError !== null) {
|
||||
return t(connectionError);
|
||||
}
|
||||
return fieldTypeName;
|
||||
}, [fieldTypeName, t, connectionError]);
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="target" id={fieldTemplate.name} position={Position.Left} style={handleStyles}>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
data-is-batch-field={fieldTemplate.type.batch}
|
||||
data-is-model-field={isModelField}
|
||||
data-is-connection-in-progress={true}
|
||||
data-is-connection-start-field={isConnectionStartField}
|
||||
data-is-connection-valid={connectionError === null}
|
||||
backgroundColor={fieldTemplate.type.cardinality === 'SINGLE' ? fieldColor : 'base.900'}
|
||||
borderColor={fieldColor}
|
||||
/>
|
||||
</Handle>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
);
|
||||
ConnectionInProgressHandle.displayName = 'ConnectionInProgressHandle';
|
||||
|
||||
@@ -4,7 +4,7 @@ import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useEditable } from 'common/hooks/useEditable';
|
||||
import { InputFieldTooltipContent } from 'features/nodes/components/flow/nodes/Invocation/fields/InputFieldTooltipContent';
|
||||
import {
|
||||
useConnectionValidationResult,
|
||||
useConnectionErrorTKey,
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
@@ -47,7 +47,7 @@ export const InputFieldTitle = memo((props: Props) => {
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const isConnectionInProgress = useIsConnectionInProgress();
|
||||
const validationResult = useConnectionValidationResult(nodeId, fieldName, 'target');
|
||||
const connectionError = useConnectionErrorTKey(nodeId, fieldName, 'target');
|
||||
|
||||
const dispatch = useAppDispatch();
|
||||
const defaultTitle = useMemo(() => fieldTemplateTitle || t('nodes.unknownField'), [fieldTemplateTitle, t]);
|
||||
@@ -76,7 +76,7 @@ export const InputFieldTitle = memo((props: Props) => {
|
||||
noOfLines={1}
|
||||
data-is-invalid={isInvalid}
|
||||
data-is-disabled={
|
||||
(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField) || isConnected
|
||||
(isConnectionInProgress && connectionError !== null && !isConnectionStartField) || isConnected
|
||||
}
|
||||
onDoubleClick={editable.startEditing}
|
||||
>
|
||||
|
||||
@@ -3,13 +3,14 @@ import { Box, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { Handle, Position } from '@xyflow/react';
|
||||
import { getFieldColor } from 'features/nodes/components/flow/edges/util/getEdgeColor';
|
||||
import {
|
||||
useConnectionValidationResult,
|
||||
useConnectionErrorTKey,
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
import { useOutputFieldTemplate } from 'features/nodes/hooks/useOutputFieldTemplate';
|
||||
import { useFieldTypeName } from 'features/nodes/hooks/usePrettyFieldType';
|
||||
import { HANDLE_TOOLTIP_OPEN_DELAY, MODEL_TYPES } from 'features/nodes/types/constants';
|
||||
import type { FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -60,33 +61,60 @@ const handleStyles = {
|
||||
} satisfies CSSProperties;
|
||||
|
||||
export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const fieldTemplate = useOutputFieldTemplate(nodeId, fieldName);
|
||||
const fieldTypeName = useFieldTypeName(fieldTemplate.type);
|
||||
const fieldColor = useMemo(() => getFieldColor(fieldTemplate.type), [fieldTemplate.type]);
|
||||
const isModelField = useMemo(() => MODEL_TYPES.some((t) => t === fieldTemplate.type.name), [fieldTemplate.type]);
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source');
|
||||
const isConnectionInProgress = useIsConnectionInProgress();
|
||||
const validationResult = useConnectionValidationResult(nodeId, fieldName, 'source');
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (isConnectionInProgress && validationResult.messageTKey) {
|
||||
return t(validationResult.messageTKey);
|
||||
}
|
||||
return fieldTypeName;
|
||||
}, [fieldTypeName, isConnectionInProgress, t, validationResult.messageTKey]);
|
||||
if (isConnectionInProgress) {
|
||||
return (
|
||||
<ConnectionInProgressHandle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
fieldTemplate={fieldTemplate}
|
||||
fieldTypeName={fieldTypeName}
|
||||
fieldColor={fieldColor}
|
||||
isModelField={isModelField}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="end" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<IdleHandle
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
fieldTemplate={fieldTemplate}
|
||||
fieldTypeName={fieldTypeName}
|
||||
fieldColor={fieldColor}
|
||||
isModelField={isModelField}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
OutputFieldHandle.displayName = 'OutputFieldHandle';
|
||||
|
||||
type HandleCommonProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
fieldTemplate: FieldOutputTemplate;
|
||||
fieldTypeName: string;
|
||||
fieldColor: string;
|
||||
isModelField: boolean;
|
||||
};
|
||||
|
||||
const IdleHandle = memo(({ fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
return (
|
||||
<Tooltip label={fieldTypeName} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
data-is-batch-field={fieldTemplate.type.batch}
|
||||
data-is-model-field={isModelField}
|
||||
data-is-connection-in-progress={isConnectionInProgress}
|
||||
data-is-connection-start-field={isConnectionStartField}
|
||||
data-is-connection-valid={validationResult.isValid}
|
||||
data-is-connection-in-progress={false}
|
||||
data-is-connection-start-field={false}
|
||||
data-is-connection-valid={false}
|
||||
backgroundColor={fieldTemplate.type.cardinality === 'SINGLE' ? fieldColor : 'base.900'}
|
||||
borderColor={fieldColor}
|
||||
/>
|
||||
@@ -94,5 +122,38 @@ export const OutputFieldHandle = memo(({ nodeId, fieldName }: Props) => {
|
||||
</Tooltip>
|
||||
);
|
||||
});
|
||||
IdleHandle.displayName = 'IdleHandle';
|
||||
|
||||
OutputFieldHandle.displayName = 'OutputFieldHandle';
|
||||
const ConnectionInProgressHandle = memo(
|
||||
({ nodeId, fieldName, fieldTemplate, fieldTypeName, fieldColor, isModelField }: HandleCommonProps) => {
|
||||
const { t } = useTranslation();
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'target');
|
||||
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'target');
|
||||
|
||||
const tooltip = useMemo(() => {
|
||||
if (connectionErrorTKey !== null) {
|
||||
return t(connectionErrorTKey);
|
||||
}
|
||||
return fieldTypeName;
|
||||
}, [fieldTypeName, t, connectionErrorTKey]);
|
||||
|
||||
return (
|
||||
<Tooltip label={tooltip} placement="start" openDelay={HANDLE_TOOLTIP_OPEN_DELAY}>
|
||||
<Handle type="source" id={fieldTemplate.name} position={Position.Right} style={handleStyles}>
|
||||
<Box
|
||||
sx={sx}
|
||||
data-cardinality={fieldTemplate.type.cardinality}
|
||||
data-is-batch-field={fieldTemplate.type.batch}
|
||||
data-is-model-field={isModelField}
|
||||
data-is-connection-in-progress={true}
|
||||
data-is-connection-start-field={isConnectionStartField}
|
||||
data-is-connection-valid={connectionErrorTKey === null}
|
||||
backgroundColor={fieldTemplate.type.cardinality === 'SINGLE' ? fieldColor : 'base.900'}
|
||||
borderColor={fieldColor}
|
||||
/>
|
||||
</Handle>
|
||||
</Tooltip>
|
||||
);
|
||||
}
|
||||
);
|
||||
ConnectionInProgressHandle.displayName = 'ConnectionInProgressHandle';
|
||||
|
||||
@@ -2,7 +2,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { OutputFieldTooltipContent } from 'features/nodes/components/flow/nodes/Invocation/fields/OutputFieldTooltipContent';
|
||||
import {
|
||||
useConnectionValidationResult,
|
||||
useConnectionErrorTKey,
|
||||
useIsConnectionInProgress,
|
||||
useIsConnectionStartField,
|
||||
} from 'features/nodes/hooks/useFieldConnectionState';
|
||||
@@ -31,7 +31,7 @@ export const OutputFieldTitle = memo(({ nodeId, fieldName }: Props) => {
|
||||
const isConnected = useInputFieldIsConnected(nodeId, fieldName);
|
||||
const isConnectionStartField = useIsConnectionStartField(nodeId, fieldName, 'source');
|
||||
const isConnectionInProgress = useIsConnectionInProgress();
|
||||
const validationResult = useConnectionValidationResult(nodeId, fieldName, 'source');
|
||||
const connectionErrorTKey = useConnectionErrorTKey(nodeId, fieldName, 'source');
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
@@ -41,7 +41,7 @@ export const OutputFieldTitle = memo(({ nodeId, fieldName }: Props) => {
|
||||
>
|
||||
<Text
|
||||
data-is-disabled={
|
||||
(isConnectionInProgress && !validationResult.isValid && !isConnectionStartField) || isConnected
|
||||
(isConnectionInProgress && connectionErrorTKey !== null && !isConnectionStartField) || isConnected
|
||||
}
|
||||
sx={sx}
|
||||
>
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Icon, IconButton } from '@invoke-ai/ui-library';
|
||||
import { useUpdateNodeInternals } from '@xyflow/react';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
@@ -10,6 +11,15 @@ interface Props {
|
||||
isOpen: boolean;
|
||||
}
|
||||
|
||||
const iconSx: SystemStyleObject = {
|
||||
transitionProperty: 'transform',
|
||||
transitionDuration: 'normal',
|
||||
transform: 'rotate(180deg)',
|
||||
'&[data-is-open="true"]': {
|
||||
transform: 'rotate(0deg)',
|
||||
},
|
||||
};
|
||||
|
||||
const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
@@ -28,14 +38,7 @@ const NodeCollapseButton = ({ nodeId, isOpen }: Props) => {
|
||||
w={8}
|
||||
h={8}
|
||||
variant="link"
|
||||
icon={
|
||||
<Icon
|
||||
as={PiCaretUpBold}
|
||||
transform={isOpen ? 'rotate(0deg)' : 'rotate(180deg)'}
|
||||
transitionProperty="transform"
|
||||
transitionDuration="normal"
|
||||
/>
|
||||
}
|
||||
icon={<Icon as={PiCaretUpBold} sx={iconSx} data-is-open={isOpen} />}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -10,18 +10,18 @@ import {
|
||||
import { makeConnectionErrorSelector } from 'features/nodes/store/util/makeConnectionErrorSelector';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useConnectionValidationResult = (nodeId: string, fieldName: string, handleType: HandleType) => {
|
||||
export const useConnectionErrorTKey = (nodeId: string, fieldName: string, handleType: HandleType): string | null => {
|
||||
const pendingConnection = useStore($pendingConnection);
|
||||
const templates = useStore($templates);
|
||||
const edgePendingUpdate = useStore($edgePendingUpdate);
|
||||
|
||||
const selectValidationResult = useMemo(
|
||||
const selectConnectionError = useMemo(
|
||||
() => makeConnectionErrorSelector(templates, nodeId, fieldName, handleType, pendingConnection, edgePendingUpdate),
|
||||
[templates, nodeId, fieldName, handleType, pendingConnection, edgePendingUpdate]
|
||||
);
|
||||
|
||||
const validationResult = useAppSelector(selectValidationResult);
|
||||
return validationResult;
|
||||
const connectionError = useAppSelector(selectConnectionError);
|
||||
return connectionError;
|
||||
};
|
||||
|
||||
export const useIsConnectionStartField = (nodeId: string, fieldName: string, handleType: HandleType) => {
|
||||
|
||||
@@ -21,7 +21,7 @@ export const useIsValidConnection = (): IsValidConnection<AnyEdge> => {
|
||||
const edgePendingUpdate = $edgePendingUpdate.get();
|
||||
const { nodes, edges } = selectNodesSlice(store.getState());
|
||||
|
||||
const validationResult = validateConnection(
|
||||
const connectionErrorTKey = validateConnection(
|
||||
{ source, sourceHandle, target, targetHandle },
|
||||
nodes,
|
||||
edges,
|
||||
@@ -30,7 +30,7 @@ export const useIsValidConnection = (): IsValidConnection<AnyEdge> => {
|
||||
shouldValidateGraph
|
||||
);
|
||||
|
||||
return validationResult.isValid;
|
||||
return connectionErrorTKey === null;
|
||||
},
|
||||
[templates, shouldValidateGraph, store]
|
||||
);
|
||||
|
||||
@@ -126,7 +126,7 @@ const _pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
||||
assert(!isNil(targetHandle));
|
||||
|
||||
// Validate the edge before adding it
|
||||
const validationResult = validateConnection(
|
||||
const connectionErrorTKey = validateConnection(
|
||||
{ source, sourceHandle, target, targetHandle },
|
||||
validationNodes,
|
||||
validationEdges,
|
||||
@@ -135,10 +135,10 @@ const _pasteSelection = (withEdgesToCopiedNodes?: boolean) => {
|
||||
true
|
||||
);
|
||||
// If the edge is invalid, log a warning and skip it
|
||||
if (!validationResult.isValid) {
|
||||
if (connectionErrorTKey !== null) {
|
||||
log.warn(
|
||||
{ edge: { source, sourceHandle, target, targetHandle } },
|
||||
`Invalid edge, cannot paste: ${t(validationResult.messageTKey)}`
|
||||
`Invalid edge, cannot paste: ${t(connectionErrorTKey)}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -105,8 +105,8 @@ export const getTargetCandidateFields = (
|
||||
|
||||
const targetCandidateFields = map(targetTemplate.inputs).filter((field) => {
|
||||
const c = { source, sourceHandle, target, targetHandle: field.name };
|
||||
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
|
||||
return r.isValid;
|
||||
const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
|
||||
return connectionErrorTKey === null;
|
||||
});
|
||||
|
||||
return targetCandidateFields;
|
||||
@@ -141,8 +141,8 @@ export const getSourceCandidateFields = (
|
||||
|
||||
const sourceCandidateFields = map(sourceTemplate.outputs).filter((field) => {
|
||||
const c = { source, sourceHandle: field.name, target, targetHandle };
|
||||
const r = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
|
||||
return r.isValid;
|
||||
const connectionErrorTKey = validateConnection(c, nodes, edges, templates, edgePendingUpdate, true);
|
||||
return connectionErrorTKey === null;
|
||||
});
|
||||
|
||||
return sourceCandidateFields;
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import type { HandleType } from '@xyflow/react';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { NodesState, PendingConnection, Templates } from 'features/nodes/store/types';
|
||||
import { buildRejectResult, validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import type { AnyEdge } from 'features/nodes/types/invocation';
|
||||
|
||||
/**
|
||||
@@ -25,18 +25,18 @@ export const makeConnectionErrorSelector = (
|
||||
pendingConnection: PendingConnection | null,
|
||||
edgePendingUpdate: AnyEdge | null
|
||||
) => {
|
||||
return createMemoizedSelector(selectNodesSlice, (nodesSlice: NodesState) => {
|
||||
return createSelector(selectNodesSlice, (nodesSlice: NodesState): string | null => {
|
||||
const { nodes, edges } = nodesSlice;
|
||||
|
||||
if (!pendingConnection) {
|
||||
return buildRejectResult('nodes.noConnectionInProgress');
|
||||
return 'nodes.noConnectionInProgress';
|
||||
}
|
||||
|
||||
if (handleType === pendingConnection.handleType) {
|
||||
if (handleType === 'source') {
|
||||
return buildRejectResult('nodes.cannotConnectOutputToOutput');
|
||||
return 'nodes.cannotConnectOutputToOutput';
|
||||
}
|
||||
return buildRejectResult('nodes.cannotConnectInputToInput');
|
||||
return 'nodes.cannotConnectInputToInput';
|
||||
}
|
||||
|
||||
// we have to figure out which is the target and which is the source
|
||||
|
||||
@@ -3,13 +3,13 @@ import { set } from 'lodash-es';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
|
||||
import { buildAcceptResult, buildRejectResult, validateConnection } from './validateConnection';
|
||||
import { validateConnection } from './validateConnection';
|
||||
|
||||
describe(validateConnection.name, () => {
|
||||
it('should reject invalid connection to self', () => {
|
||||
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||
const r = validateConnection(c, [], [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
|
||||
expect(r).toEqual('nodes.cannotConnectToSelf');
|
||||
});
|
||||
|
||||
describe('missing nodes', () => {
|
||||
@@ -19,12 +19,12 @@ describe(validateConnection.name, () => {
|
||||
|
||||
it('should reject missing source node', () => {
|
||||
const r = validateConnection(c, [n2], [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
|
||||
expect(r).toEqual('nodes.missingNode');
|
||||
});
|
||||
|
||||
it('should reject missing target node', () => {
|
||||
const r = validateConnection(c, [n1], [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingNode'));
|
||||
expect(r).toEqual('nodes.missingNode');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -36,12 +36,12 @@ describe(validateConnection.name, () => {
|
||||
|
||||
it('should reject missing source template', () => {
|
||||
const r = validateConnection(c, nodes, [], { sub }, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
|
||||
expect(r).toEqual('nodes.missingInvocationTemplate');
|
||||
});
|
||||
|
||||
it('should reject missing target template', () => {
|
||||
const r = validateConnection(c, nodes, [], { add }, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingInvocationTemplate'));
|
||||
expect(r).toEqual('nodes.missingInvocationTemplate');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -53,13 +53,13 @@ describe(validateConnection.name, () => {
|
||||
it('should reject missing source field template', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'invalid', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
|
||||
expect(r).toEqual('nodes.missingFieldTemplate');
|
||||
});
|
||||
|
||||
it('should reject missing target field template', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'invalid' };
|
||||
const r = validateConnection(c, nodes, [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.missingFieldTemplate'));
|
||||
expect(r).toEqual('nodes.missingFieldTemplate');
|
||||
});
|
||||
});
|
||||
|
||||
@@ -69,19 +69,19 @@ describe(validateConnection.name, () => {
|
||||
it('should accept non-duplicate connections', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, [n1, n2], [], templates, null);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
it('should reject duplicate connections', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const e = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const r = validateConnection(c, [n1, n2], [e], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotDuplicateConnection'));
|
||||
expect(r).toEqual('nodes.cannotDuplicateConnection');
|
||||
});
|
||||
it('should accept duplicate connections if the duplicate is an ignored edge', () => {
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const e = buildEdge(n1.id, 'value', n2.id, 'a');
|
||||
const r = validateConnection(c, [n1, n2], [e], templates, e);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -95,7 +95,7 @@ describe(validateConnection.name, () => {
|
||||
const n2 = buildNode(addWithDirectAField);
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, [n1, n2], [], { add, addWithDirectAField }, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToDirectInput'));
|
||||
expect(r).toEqual('nodes.cannotConnectToDirectInput');
|
||||
});
|
||||
|
||||
it('should reject connection to a collect node with mismatched item types', () => {
|
||||
@@ -107,7 +107,7 @@ describe(validateConnection.name, () => {
|
||||
const edges = [e1];
|
||||
const c = { source: n3.id, sourceHandle: 'vae', target: n2.id, targetHandle: 'item' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes'));
|
||||
expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
});
|
||||
|
||||
it('should accept connection to a collect node with matching item types', () => {
|
||||
@@ -119,7 +119,7 @@ describe(validateConnection.name, () => {
|
||||
const edges = [e1];
|
||||
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'item' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
|
||||
it('should reject connections to target field that is already connected', () => {
|
||||
@@ -131,7 +131,7 @@ describe(validateConnection.name, () => {
|
||||
const edges = [e1];
|
||||
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.inputMayOnlyHaveOneConnection'));
|
||||
expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection');
|
||||
});
|
||||
|
||||
it('should accept connections to target field that is already connected (ignored edge)', () => {
|
||||
@@ -143,7 +143,7 @@ describe(validateConnection.name, () => {
|
||||
const edges = [e1];
|
||||
const c = { source: n3.id, sourceHandle: 'value', target: n2.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, e1);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
|
||||
it('should reject connections between invalid types', () => {
|
||||
@@ -152,7 +152,7 @@ describe(validateConnection.name, () => {
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.fieldTypesMustMatch'));
|
||||
expect(r).toEqual('nodes.fieldTypesMustMatch');
|
||||
});
|
||||
|
||||
it('should reject connections that would create cycles', () => {
|
||||
@@ -163,14 +163,14 @@ describe(validateConnection.name, () => {
|
||||
const edges = [e1];
|
||||
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
expect(r).toEqual('nodes.connectionWouldCreateCycle');
|
||||
});
|
||||
|
||||
describe('non-strict mode', () => {
|
||||
it('should reject connections from self to self in non-strict mode', () => {
|
||||
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||
const r = validateConnection(c, [], [], templates, null, false);
|
||||
expect(r).toEqual(buildRejectResult('nodes.cannotConnectToSelf'));
|
||||
expect(r).toEqual('nodes.cannotConnectToSelf');
|
||||
});
|
||||
it('should reject connections that create cycles in non-strict mode', () => {
|
||||
const n1 = buildNode(add);
|
||||
@@ -180,7 +180,7 @@ describe(validateConnection.name, () => {
|
||||
const edges = [e1];
|
||||
const c = { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' };
|
||||
const r = validateConnection(c, nodes, edges, templates, null, false);
|
||||
expect(r).toEqual(buildRejectResult('nodes.connectionWouldCreateCycle'));
|
||||
expect(r).toEqual('nodes.connectionWouldCreateCycle');
|
||||
});
|
||||
it('should otherwise allow invalid connections in non-strict mode', () => {
|
||||
const n1 = buildNode(add);
|
||||
@@ -188,7 +188,7 @@ describe(validateConnection.name, () => {
|
||||
const nodes = [n1, n2];
|
||||
const c = { source: n1.id, sourceHandle: 'value', target: n2.id, targetHandle: 'image' };
|
||||
const r = validateConnection(c, nodes, [], templates, null, false);
|
||||
expect(r).toEqual(buildAcceptResult());
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,16 +9,6 @@ import type { SetNonNullable } from 'type-fest';
|
||||
|
||||
type Connection = SetNonNullable<NullableConnection>;
|
||||
|
||||
type ValidationResult =
|
||||
| {
|
||||
isValid: true;
|
||||
messageTKey?: string;
|
||||
}
|
||||
| {
|
||||
isValid: false;
|
||||
messageTKey: string;
|
||||
};
|
||||
|
||||
type ValidateConnectionFunc = (
|
||||
connection: Connection,
|
||||
nodes: AnyNode[],
|
||||
@@ -26,7 +16,7 @@ type ValidateConnectionFunc = (
|
||||
templates: Templates,
|
||||
ignoreEdge: AnyEdge | null,
|
||||
strict?: boolean
|
||||
) => ValidationResult;
|
||||
) => string | null;
|
||||
|
||||
const getEqualityPredicate =
|
||||
(c: Connection) =>
|
||||
@@ -45,12 +35,16 @@ const getTargetEqualityPredicate =
|
||||
return e.target === c.target && e.targetHandle === c.targetHandle;
|
||||
};
|
||||
|
||||
export const buildAcceptResult = (): ValidationResult => ({ isValid: true });
|
||||
export const buildRejectResult = (messageTKey: string): ValidationResult => ({ isValid: false, messageTKey });
|
||||
|
||||
export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, templates, ignoreEdge, strict = true) => {
|
||||
export const validateConnection: ValidateConnectionFunc = (
|
||||
c,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
ignoreEdge,
|
||||
strict = true
|
||||
): string | null => {
|
||||
if (c.source === c.target) {
|
||||
return buildRejectResult('nodes.cannotConnectToSelf');
|
||||
return 'nodes.cannotConnectToSelf';
|
||||
}
|
||||
|
||||
if (strict) {
|
||||
@@ -65,66 +59,66 @@ export const validateConnection: ValidateConnectionFunc = (c, nodes, edges, temp
|
||||
|
||||
if (filteredEdges.some(getEqualityPredicate(c))) {
|
||||
// We already have a connection from this source to this target
|
||||
return buildRejectResult('nodes.cannotDuplicateConnection');
|
||||
return 'nodes.cannotDuplicateConnection';
|
||||
}
|
||||
|
||||
const sourceNode = nodes.find((n) => n.id === c.source);
|
||||
if (!sourceNode) {
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
return 'nodes.missingNode';
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((n) => n.id === c.target);
|
||||
if (!targetNode) {
|
||||
return buildRejectResult('nodes.missingNode');
|
||||
return 'nodes.missingNode';
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!targetTemplate) {
|
||||
return buildRejectResult('nodes.missingInvocationTemplate');
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
return buildRejectResult('nodes.missingFieldTemplate');
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
if (targetFieldTemplate.input === 'direct') {
|
||||
return buildRejectResult('nodes.cannotConnectToDirectInput');
|
||||
return 'nodes.cannotConnectToDirectInput';
|
||||
}
|
||||
|
||||
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types.
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
|
||||
return buildRejectResult('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
return 'nodes.cannotMixAndMatchCollectionItemTypes';
|
||||
}
|
||||
}
|
||||
|
||||
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
|
||||
// CollectionItemField inputs can have multiple input connections
|
||||
if (targetFieldTemplate.type.name !== 'CollectionItemField') {
|
||||
return buildRejectResult('nodes.inputMayOnlyHaveOneConnection');
|
||||
return 'nodes.inputMayOnlyHaveOneConnection';
|
||||
}
|
||||
}
|
||||
|
||||
if (!validateConnectionTypes(sourceFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return buildRejectResult('nodes.fieldTypesMustMatch');
|
||||
return 'nodes.fieldTypesMustMatch';
|
||||
}
|
||||
}
|
||||
|
||||
if (getHasCycles(c.source, c.target, nodes, edges)) {
|
||||
return buildRejectResult('nodes.connectionWouldCreateCycle');
|
||||
return 'nodes.connectionWouldCreateCycle';
|
||||
}
|
||||
|
||||
return buildAcceptResult();
|
||||
return null;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user