Merge branch 'main' into external-models

This commit is contained in:
Alexander Eichhorn
2026-04-14 02:09:56 +02:00
committed by GitHub
28 changed files with 2260 additions and 128 deletions

View File

@@ -1449,6 +1449,8 @@
"notes": "Notes",
"description": "Description",
"notesDescription": "Add notes about your workflow",
"addConnector": "Add Connector",
"deleteConnector": "Delete Connector",
"problemSettingTitle": "Problem Setting Title",
"resetToDefaultValue": "Reset to default value",
"reloadNodeTemplates": "Reload Node Templates",

View File

@@ -1,4 +1,4 @@
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { Menu, MenuButton, MenuItem, MenuList, Portal, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import type {
EdgeChange,
@@ -32,7 +32,9 @@ import {
$edgePendingUpdate,
$lastEdgeUpdateMouseEvent,
$pendingConnection,
$templates,
$viewport,
connectorInserted,
edgesChanged,
nodesChanged,
redo,
@@ -46,18 +48,24 @@ import {
selectNodes,
selectNodesSlice,
} from 'features/nodes/store/selectors';
import { getConnectorDeletionSpliceConnections } from 'features/nodes/store/util/connectorTopology';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice';
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode';
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
import type { CSSProperties, MouseEvent, RefObject } from 'react';
import { memo, useCallback, useMemo, useRef } from 'react';
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
import { useHotkeys } from 'react-hotkeys-hook';
import { useTranslation } from 'react-i18next';
import { PiPlugsConnectedBold, PiTrashBold } from 'react-icons/pi';
import CustomConnectionLine from './connectionLines/CustomConnectionLine';
import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge';
import InvocationDefaultEdge from './edges/InvocationDefaultEdge';
import ConnectorNode from './nodes/Connector/ConnectorNode';
import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode';
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
import NotesNode from './nodes/Notes/NotesNode';
@@ -70,6 +78,7 @@ const edgeTypes = {
const nodeTypes = {
invocation: InvocationNodeWrapper,
connector: ConnectorNode,
current_image: CurrentImageNode,
notes: NotesNode,
} as const;
@@ -81,17 +90,69 @@ const snapGrid: [number, number] = [25, 25];
const selectCancelConnection = (state: ReactFlowState) => state.cancelConnection;
type WorkflowContextMenuState =
| {
kind: 'pane';
clientX: number;
clientY: number;
pageX: number;
pageY: number;
}
| {
kind: 'connector';
connectorId: string;
pageX: number;
pageY: number;
}
| null;
const getWorkflowContextMenuState = (
event: globalThis.MouseEvent,
flowWrapper: HTMLDivElement | null
): WorkflowContextMenuState => {
if (event.shiftKey || !(event.target instanceof Element) || !flowWrapper?.contains(event.target)) {
return null;
}
const connectorId = event.target.closest<HTMLElement>('[data-connector-node-id]')?.dataset.connectorNodeId;
if (connectorId) {
return {
kind: 'connector',
connectorId,
pageX: event.pageX,
pageY: event.pageY,
};
}
const paneTarget = event.target.closest('.react-flow__pane');
if (paneTarget && flowWrapper.contains(paneTarget)) {
return {
kind: 'pane',
clientX: event.clientX,
clientY: event.clientY,
pageX: event.pageX,
pageY: event.pageY,
};
}
return null;
};
export const Flow = memo(() => {
const { t } = useTranslation();
const dispatch = useAppDispatch();
const nodes = useAppSelector(selectNodes);
const edges = useAppSelector(selectEdges);
const templates = useStore($templates);
const viewport = useStore($viewport);
const shouldSnapToGrid = useAppSelector(selectShouldSnapToGrid);
const selectionMode = useAppSelector(selectSelectionMode);
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
const flowWrapper = useRef<HTMLDivElement>(null);
const pendingNodeInternalsUpdateRef = useRef<string[] | null>(null);
const isValidConnection = useIsValidConnection();
const updateNodeInternals = useUpdateNodeInternals();
const [contextMenuState, setContextMenuState] = useState<WorkflowContextMenuState>(null);
useFocusRegion('workflows', flowWrapper);
@@ -127,9 +188,16 @@ export const Flow = memo(() => {
}, []);
const { onCloseGlobal } = useGlobalMenuClose();
const handlePaneClick = useCallback(() => {
onCloseGlobal();
}, [onCloseGlobal]);
const handlePaneClick: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onPaneClick']> = useCallback(
(event) => {
if ('button' in event && event.button !== 0) {
return;
}
onCloseGlobal();
setContextMenuState(null);
},
[onCloseGlobal]
);
const onInit: OnInit<AnyNode, AnyEdge> = useCallback((flow) => {
$flow.set(flow);
@@ -147,6 +215,149 @@ export const Flow = memo(() => {
}
}, []);
useEffect(() => {
const pendingNodeIds = pendingNodeInternalsUpdateRef.current;
if (!pendingNodeIds) {
return;
}
pendingNodeInternalsUpdateRef.current = null;
const frameId = requestAnimationFrame(() => {
updateNodeInternals([...new Set(pendingNodeIds)]);
});
return () => {
cancelAnimationFrame(frameId);
};
}, [edges, nodes, updateNodeInternals]);
const addConnectorAtPaneMenuPosition = useCallback(() => {
if (contextMenuState?.kind !== 'pane') {
return;
}
const flow = $flow.get();
if (!flow) {
return;
}
const connector = buildConnectorNode(
flow.screenToFlowPosition({
x: contextMenuState.clientX,
y: contextMenuState.clientY,
})
);
dispatch(nodesChanged([{ type: 'add', item: connector }]));
setContextMenuState(null);
}, [contextMenuState, dispatch]);
const connectorSpliceConnections = useMemo(
() =>
contextMenuState?.kind === 'connector'
? getConnectorDeletionSpliceConnections(
contextMenuState.connectorId,
nodes,
edges,
templates,
validateConnection
)
: null,
[contextMenuState, edges, nodes, templates]
);
const deleteConnectorFromContextMenu = useCallback(() => {
if (contextMenuState?.kind !== 'connector' || !connectorSpliceConnections) {
return;
}
const connectorEdgeRemovals: EdgeChange<AnyEdge>[] = edges
.filter((edge) => edge.source === contextMenuState.connectorId || edge.target === contextMenuState.connectorId)
.map((edge) => ({ type: 'remove', id: edge.id }));
const spliceEdgeAdditions: EdgeChange<AnyEdge>[] = connectorSpliceConnections.map((connection) => ({
type: 'add',
item: connectionToEdge(connection),
}));
pendingNodeInternalsUpdateRef.current = [
contextMenuState.connectorId,
...connectorSpliceConnections.flatMap((connection) => [connection.source, connection.target]),
];
dispatch(edgesChanged([...connectorEdgeRemovals, ...spliceEdgeAdditions]));
dispatch(nodesChanged([{ type: 'remove', id: contextMenuState.connectorId }]));
setContextMenuState(null);
}, [connectorSpliceConnections, contextMenuState, dispatch, edges]);
useEffect(() => {
const onWindowContextMenu = (event: globalThis.MouseEvent) => {
const nextContextMenuState = getWorkflowContextMenuState(event, flowWrapper.current);
if (!nextContextMenuState) {
return;
}
event.preventDefault();
event.stopPropagation();
setContextMenuState(nextContextMenuState);
};
window.addEventListener('contextmenu', onWindowContextMenu, { capture: true });
return () => {
window.removeEventListener('contextmenu', onWindowContextMenu, { capture: true });
};
}, []);
const renderContextMenu = useCallback(() => {
if (contextMenuState?.kind === 'pane') {
return (
<MenuList visibility="visible">
<MenuItem icon={<PiPlugsConnectedBold />} onClick={addConnectorAtPaneMenuPosition}>
{t('nodes.addConnector')}
</MenuItem>
</MenuList>
);
}
if (contextMenuState?.kind === 'connector') {
return (
<MenuList visibility="visible">
<MenuItem
icon={<PiTrashBold />}
onClick={deleteConnectorFromContextMenu}
isDisabled={!connectorSpliceConnections}
isDestructive
>
{t('nodes.deleteConnector')}
</MenuItem>
</MenuList>
);
}
return <MenuList visibility="visible" />;
}, [addConnectorAtPaneMenuPosition, connectorSpliceConnections, contextMenuState, deleteConnectorFromContextMenu, t]);
const closeContextMenu = useCallback(() => {
setContextMenuState(null);
}, []);
const onEdgeDoubleClick = useCallback<NonNullable<ReactFlowProps['onEdgeDoubleClick']>>(
(event, edge) => {
if (edge.type !== 'default' || edge.hidden) {
return;
}
const flow = $flow.get();
if (!flow) {
return;
}
const connector = buildConnectorNode(
flow.screenToFlowPosition({
x: event.clientX,
y: event.clientY,
})
);
dispatch(connectorInserted({ edgeId: edge.id, connector }));
updateNodeInternals([edge.source, edge.target, connector.id]);
},
[dispatch, updateNodeInternals]
);
// #region Updatable Edges
/**
@@ -209,16 +420,130 @@ export const Flow = memo(() => {
// #endregion
const renderedNodes = useMemo(() => nodes, [nodes]);
const renderedEdges = useMemo(() => edges, [edges]);
const contextMenuPosition = contextMenuState ? { x: contextMenuState.pageX, y: contextMenuState.pageY } : null;
const contextMenuKey = contextMenuPosition ? `${contextMenuPosition.x}-${contextMenuPosition.y}` : 'closed';
return (
<>
<FlowSurface
flowWrapper={flowWrapper}
viewport={viewport}
renderedNodes={renderedNodes}
renderedEdges={renderedEdges}
onInit={onInit}
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
onEdgesChange={onEdgesChange}
onReconnect={onReconnect}
onReconnectStart={onReconnectStart}
onReconnectEnd={onReconnectEnd}
onConnectStart={onConnectStart}
onConnect={onConnect}
onConnectEnd={onConnectEnd}
handleMoveEnd={handleMoveEnd}
onEdgeDoubleClick={onEdgeDoubleClick}
isValidConnection={isValidConnection}
shouldSnapToGrid={shouldSnapToGrid}
flowStyles={flowStyles}
handlePaneClick={handlePaneClick}
selectionMode={selectionMode}
/>
<Portal>
<Menu
key={contextMenuKey}
isOpen={contextMenuPosition !== null}
onClose={closeContextMenu}
placement="auto-start"
gutter={0}
>
<MenuButton
aria-hidden
position="absolute"
left={contextMenuPosition?.x ?? -9999}
top={contextMenuPosition?.y ?? -9999}
w={1}
h={1}
pointerEvents="none"
bg="transparent"
/>
{renderContextMenu()}
</Menu>
</Portal>
<HotkeyIsolator flowWrapper={flowWrapper} />
</>
);
});
Flow.displayName = 'Flow';
type FlowSurfaceProps = {
flowWrapper: { current: HTMLDivElement | null };
viewport: ReactFlowProps<AnyNode, AnyEdge>['defaultViewport'];
renderedNodes: AnyNode[];
renderedEdges: AnyEdge[];
onInit: OnInit<AnyNode, AnyEdge>;
onMouseMove: (event: MouseEvent<HTMLDivElement>) => void;
onNodesChange: OnNodesChange<AnyNode>;
onEdgesChange: OnEdgesChange<AnyEdge>;
onReconnect: OnReconnect;
onReconnectStart: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onReconnectStart']>;
onReconnectEnd: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onReconnectEnd']>;
onConnectStart: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onConnectStart']>;
onConnect: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onConnect']>;
onConnectEnd: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onConnectEnd']>;
handleMoveEnd: OnMoveEnd;
onEdgeDoubleClick: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onEdgeDoubleClick']>;
isValidConnection: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['isValidConnection']>;
shouldSnapToGrid: boolean;
flowStyles: CSSProperties;
handlePaneClick: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onPaneClick']>;
selectionMode: ReturnType<typeof selectSelectionMode>;
};
const FlowSurface = memo((props: FlowSurfaceProps) => {
const {
flowWrapper,
viewport,
renderedNodes,
renderedEdges,
onInit,
onMouseMove,
onNodesChange,
onEdgesChange,
onReconnect,
onReconnectStart,
onReconnectEnd,
onConnectStart,
onConnect,
onConnectEnd,
handleMoveEnd,
onEdgeDoubleClick,
isValidConnection,
shouldSnapToGrid,
flowStyles,
handlePaneClick,
selectionMode,
} = props;
const setFlowWrapperElement = useCallback(
(el: HTMLDivElement | null) => {
flowWrapper.current = el;
},
[flowWrapper]
);
return (
<div ref={setFlowWrapperElement} style={{ width: '100%', height: '100%' }}>
<ReactFlow<AnyNode, AnyEdge>
id="workflow-editor"
ref={flowWrapper}
defaultViewport={viewport}
nodeTypes={nodeTypes}
edgeTypes={edgeTypes}
nodes={nodes}
edges={edges}
nodes={renderedNodes}
edges={renderedEdges}
onInit={onInit}
onMouseMove={onMouseMove}
onNodesChange={onNodesChange}
@@ -230,6 +555,7 @@ export const Flow = memo(() => {
onConnect={onConnect}
onConnectEnd={onConnectEnd}
onMoveEnd={handleMoveEnd}
onEdgeDoubleClick={onEdgeDoubleClick}
connectionLineComponent={CustomConnectionLine}
isValidConnection={isValidConnection}
minZoom={0.1}
@@ -249,12 +575,11 @@ export const Flow = memo(() => {
>
<Background gap={snapGrid} offset={snapGrid} />
</ReactFlow>
<HotkeyIsolator flowWrapper={flowWrapper} />
</>
</div>
);
});
Flow.displayName = 'Flow';
FlowSurface.displayName = 'FlowSurface';
const HotkeyIsolator = memo(({ flowWrapper }: { flowWrapper: RefObject<HTMLDivElement> }) => {
const mayUndo = useAppSelector(selectMayUndo);

View File

@@ -1,9 +1,10 @@
import { createSelector } from '@reduxjs/toolkit';
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
import { selectNodes } from 'features/nodes/store/selectors';
import { selectEdges, selectNodes } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { resolveConnectorSource, resolveConnectorSourceFieldType } from 'features/nodes/store/util/connectorTopology';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { isConnectorNode } from 'features/nodes/types/invocation';
import { getFieldColor } from './getEdgeColor';
@@ -22,23 +23,20 @@ export const buildSelectEdgeColor = (
target: string,
targetHandleId: string | null | undefined
) =>
createSelector(selectNodes, selectWorkflowSettingsSlice, (nodes, workflowSettings): string => {
createSelector(selectNodes, selectEdges, selectWorkflowSettingsSlice, (nodes, edges, workflowSettings): string => {
const { shouldColorEdges } = workflowSettings;
if (!shouldColorEdges) {
return colorTokenToCssVar('base.500');
}
const sourceNode = nodes.find((node) => node.id === source);
const targetNode = nodes.find((node) => node.id === target);
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
if (!sourceNode || !sourceHandleId || !targetHandleId) {
return colorTokenToCssVar('base.500');
}
const sourceNodeTemplate = templates[sourceNode.data.type];
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
const sourceType = isConnectorNode(sourceNode)
? resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates)
: templates[sourceNode.data.type]?.outputs[sourceHandleId]?.type;
return sourceType ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
});
@@ -50,7 +48,7 @@ export const buildSelectEdgeLabel = (
target: string,
targetHandleId: string | null | undefined
) =>
createSelector(selectNodes, (nodes): string | null => {
createSelector(selectNodes, selectEdges, (nodes, edges): string | null => {
const sourceNode = nodes.find((node) => node.id === source);
const targetNode = nodes.find((node) => node.id === target);
@@ -58,8 +56,12 @@ export const buildSelectEdgeLabel = (
return null;
}
const sourceNodeTemplate = templates[sourceNode.data.type];
const resolvedSource = isConnectorNode(sourceNode) ? resolveConnectorSource(sourceNode.id, nodes, edges) : null;
const sourceTemplate =
resolvedSource !== null
? templates[nodes.find((node) => node.id === resolvedSource.nodeId)?.data.type ?? '']
: templates[sourceNode.data.type];
const targetNodeTemplate = templates[targetNode.data.type];
return `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
return `${sourceTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
});

View File

@@ -0,0 +1,97 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Icon } from '@invoke-ai/ui-library';
import type { Node, NodeProps } from '@xyflow/react';
import { Handle, Position } from '@xyflow/react';
import NonInvocationNodeWrapper from 'features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper';
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
import type { ConnectorNodeData } from 'features/nodes/types/invocation';
import type { CSSProperties } from 'react';
import { memo } from 'react';
import { PiDotOutlineFill } from 'react-icons/pi';
const handleVisualSx = {
w: 3,
h: 3,
borderRadius: 'full',
borderWidth: 2,
borderColor: 'base.900',
bg: 'base.100',
pointerEvents: 'none',
} satisfies SystemStyleObject;
const handleStyles = {
position: 'absolute',
width: '1rem',
height: '1rem',
display: 'flex',
alignItems: 'center',
justifyContent: 'center',
top: '50%',
transform: 'translateY(-50%)',
zIndex: 1,
background: 'none',
border: 'none',
} satisfies CSSProperties;
const inputHandleStyles = {
...handleStyles,
insetInlineStart: 0,
justifyContent: 'flex-start',
} satisfies CSSProperties;
const outputHandleStyles = {
...handleStyles,
insetInlineEnd: 0,
justifyContent: 'flex-end',
} satisfies CSSProperties;
const ConnectorNode = ({ id, selected }: NodeProps<Node<ConnectorNodeData>>) => {
return (
<NonInvocationNodeWrapper nodeId={id} selected={selected} width={25} borderRadius="full" withChrome={false}>
<Box
data-connector-node-context-menu="true"
data-connector-node-id={id}
position="relative"
w={25}
h={25}
display="flex"
alignItems="center"
justifyContent="center"
>
<Handle
className={NO_DRAG_CLASS}
type="target"
id={CONNECTOR_INPUT_HANDLE}
position={Position.Left}
style={inputHandleStyles}
>
<Box sx={handleVisualSx} />
</Handle>
<Box
w={13}
h={13}
display="flex"
alignItems="center"
justifyContent="center"
borderRadius="full"
bg={selected ? 'base.650' : 'base.700'}
boxShadow={selected ? '0 0 0 2px var(--invoke-colors-blue-300)' : '0 0 0 1px var(--invoke-colors-base-500)'}
>
<Icon as={PiDotOutlineFill} boxSize={5} color={selected ? 'base.50' : 'base.100'} />
</Box>
<Handle
className={NO_DRAG_CLASS}
type="source"
id={CONNECTOR_OUTPUT_HANDLE}
position={Position.Right}
style={outputHandleStyles}
>
<Box sx={handleVisualSx} />
</Handle>
</Box>
</NonInvocationNodeWrapper>
);
};
export default memo(ConnectorNode);

View File

@@ -16,10 +16,12 @@ type NonInvocationNodeWrapperProps = PropsWithChildren & {
nodeId: string;
selected: boolean;
width?: ChakraProps['w'];
borderRadius?: ChakraProps['borderRadius'];
withChrome?: boolean;
};
const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => {
const { nodeId, width, children, selected } = props;
const { nodeId, width, children, selected, borderRadius = 'base', withChrome = true } = props;
const mouseOverNode = useMouseOverNode(nodeId);
const zoomToNode = useZoomToNode(nodeId);
@@ -62,14 +64,15 @@ const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => {
onMouseOut={mouseOverNode.handleMouseOut}
className={DRAG_HANDLE_CLASSNAME}
sx={containerSx}
borderRadius={borderRadius}
width={width || NODE_WIDTH}
opacity={opacity}
data-is-selected={selected}
>
<Box sx={shadowsSx} />
<Box sx={inProgressSx} data-is-in-progress={isInProgress} />
{withChrome && <Box sx={shadowsSx} />}
{withChrome && <Box sx={inProgressSx} data-is-in-progress={isInProgress} />}
{children}
<Box className="node-selection-overlay" />
{withChrome && <Box className="node-selection-overlay" />}
</Box>
);
};

View File

@@ -6,7 +6,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
export const containerSx: SystemStyleObject = {
h: 'full',
position: 'relative',
borderRadius: 'base',
borderRadius: 'inherit',
transitionProperty: 'none',
cursor: 'grab',
'--border-color': 'var(--invoke-colors-base-500)',
@@ -30,7 +30,7 @@ export const containerSx: SystemStyleObject = {
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
borderRadius: 'inherit',
transitionProperty: 'none',
pointerEvents: 'none',
shadow: '0 0 0 1px var(--border-color)',
@@ -64,7 +64,7 @@ export const shadowsSx: SystemStyleObject = {
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'base',
borderRadius: 'inherit',
pointerEvents: 'none',
zIndex: -1,
shadow: 'var(--invoke-shadows-xl), var(--invoke-shadows-base), var(--invoke-shadows-base)',
@@ -76,7 +76,7 @@ export const inProgressSx: SystemStyleObject = {
insetInlineEnd: 0,
bottom: 0,
insetInlineStart: 0,
borderRadius: 'md',
borderRadius: 'inherit',
pointerEvents: 'none',
transitionProperty: 'none',
opacity: 0.7,

View File

@@ -3,6 +3,7 @@ import { $templates } from 'features/nodes/store/nodesSlice';
import { $flow } from 'features/nodes/store/reactFlowInstance';
import { NODE_WIDTH } from 'features/nodes/types/constants';
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode';
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode';
@@ -14,7 +15,7 @@ export const useBuildNode = () => {
return useCallback(
// string here is "any invocation type"
(type: string | 'current_image' | 'notes'): AnyNode => {
(type: string | 'connector' | 'current_image' | 'notes'): AnyNode => {
const flow = $flow.get();
assert(flow !== null);
@@ -42,6 +43,10 @@ export const useBuildNode = () => {
return buildNotesNode(position);
}
if (type === 'connector') {
return buildConnectorNode(position);
}
// TODO: Keep track of invocation types so we do not need to cast this
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
const template = templates[type] as InvocationTemplate;

View File

@@ -12,9 +12,15 @@ import {
edgesChanged,
} from 'features/nodes/store/nodesSlice';
import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors';
import {
CONNECTOR_INPUT_HANDLE,
CONNECTOR_OUTPUT_HANDLE,
resolveConnectorSourceFieldType,
} from 'features/nodes/store/util/connectorTopology';
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import type { AnyEdge } from 'features/nodes/types/invocation';
import { isConnectorNode } from 'features/nodes/types/invocation';
import { useCallback, useMemo } from 'react';
import { assert } from 'tsafe';
@@ -33,6 +39,47 @@ export const useConnection = () => {
return;
}
if (isConnectorNode(node)) {
if (handleType === 'source' && handleId !== CONNECTOR_OUTPUT_HANDLE) {
return;
}
if (handleType === 'target' && handleId !== CONNECTOR_INPUT_HANDLE) {
return;
}
const resolvedSourceType =
handleType === 'source'
? resolveConnectorSourceFieldType(nodeId, nodes, selectNodesSlice(store.getState()).edges, templates)
: null;
$pendingConnection.set({
nodeId,
handleId,
handleType,
fieldTemplate:
handleType === 'source'
? {
name: CONNECTOR_OUTPUT_HANDLE,
title: 'Connector Output',
description: '',
fieldKind: 'output',
ui_hidden: false,
type: resolvedSourceType ?? { name: 'AnyField', cardinality: 'SINGLE', batch: false },
}
: {
name: CONNECTOR_INPUT_HANDLE,
title: 'Connector Input',
description: '',
fieldKind: 'input',
input: 'connection',
required: false,
default: undefined,
ui_hidden: false,
type: { name: 'AnyField', cardinality: 'SINGLE', batch: false },
},
});
return;
}
const template = templates[node.data.type];
if (!template) {
return;

View File

@@ -0,0 +1,160 @@
import { deepClone } from 'common/util/deepClone';
import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode';
import { describe, expect, it } from 'vitest';
import { connectorInserted, nodesChanged, nodesSliceConfig } from './nodesSlice';
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from './util/connectorTopology';
import { add, buildEdge, buildNode, sub } from './util/testUtils';
const buildFixedConnectorNode = (id: string) => {
const connectorNode = buildConnectorNode({ x: 0, y: 0 });
return {
...connectorNode,
id,
data: {
...connectorNode.data,
id,
},
};
};
describe('nodesSlice connector actions', () => {
it('splits a direct edge into source -> connector -> target edges when inserting a connector', () => {
const source = buildNode(add);
const target = buildNode(sub);
const connector = buildFixedConnectorNode('connector-1');
const directEdge = buildEdge(source.id, 'value', target.id, 'a');
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
initialState.nodes = [source, target];
initialState.edges = [directEdge];
const nextState = nodesSliceConfig.slice.reducer(
initialState,
connectorInserted({
edgeId: directEdge.id,
connector,
})
);
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id, connector.id]);
expect(nextState.edges).toEqual([
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
]);
});
it('splices connector outputs back to the resolved upstream source when removed', () => {
const source = buildNode(add);
const target = buildNode(sub);
const connector = buildFixedConnectorNode('connector-1');
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
initialState.nodes = [source, connector, target];
initialState.edges = [
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
];
const nextState = nodesSliceConfig.slice.reducer(
initialState,
nodesChanged([{ type: 'remove', id: connector.id }])
);
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]);
expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]);
});
it('splices one connector source back to multiple downstream targets when removed', () => {
const source = buildNode(add);
const targetA = buildNode(sub);
const targetB = buildNode(sub);
const connector = buildFixedConnectorNode('connector-1');
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
initialState.nodes = [source, connector, targetA, targetB];
initialState.edges = [
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'b'),
];
const nextState = nodesSliceConfig.slice.reducer(
initialState,
nodesChanged([{ type: 'remove', id: connector.id }])
);
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, targetA.id, targetB.id]);
expect(nextState.edges).toEqual([
buildEdge(source.id, 'value', targetA.id, 'a'),
buildEdge(source.id, 'value', targetB.id, 'b'),
]);
});
it('does not create any edges when removing a connector with no downstream targets', () => {
const source = buildNode(add);
const connector = buildFixedConnectorNode('connector-1');
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
initialState.nodes = [source, connector];
initialState.edges = [buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
const nextState = nodesSliceConfig.slice.reducer(
initialState,
nodesChanged([{ type: 'remove', id: connector.id }])
);
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id]);
expect(nextState.edges).toEqual([]);
});
it('removes a connector while preserving downstream connector edges in a chained splice case', () => {
const source = buildNode(add);
const connectorA = buildFixedConnectorNode('connector-a');
const connectorB = buildFixedConnectorNode('connector-b');
const target = buildNode(sub);
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
initialState.nodes = [source, connectorA, connectorB, target];
initialState.edges = [
buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
];
const nextState = nodesSliceConfig.slice.reducer(
initialState,
nodesChanged([{ type: 'remove', id: connectorA.id }])
);
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, connectorB.id, target.id]);
expect(nextState.edges).toHaveLength(2);
expect(nextState.edges).toEqual(
expect.arrayContaining([
buildEdge(source.id, 'value', connectorB.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
])
);
});
it('splices connector edges when the connector is removed through generic node removal', () => {
const source = buildNode(add);
const target = buildNode(sub);
const connector = buildFixedConnectorNode('connector-1');
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
initialState.nodes = [source, connector, target];
initialState.edges = [
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
];
const nextState = nodesSliceConfig.slice.reducer(
initialState,
nodesChanged([{ type: 'remove', id: connector.id }])
);
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]);
expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]);
});
});

View File

@@ -20,6 +20,13 @@ import {
reparentElement,
} from 'features/nodes/components/sidePanel/builder/form-manipulation';
import { type NodesState, zNodesState } from 'features/nodes/store/types';
import {
CONNECTOR_INPUT_HANDLE,
CONNECTOR_OUTPUT_HANDLE,
getConnectorOutputEdges,
resolveConnectorSource,
} from 'features/nodes/store/util/connectorTopology';
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type {
BoardFieldValue,
@@ -65,8 +72,8 @@ import {
zStringGeneratorFieldValue,
zStylePresetFieldValue,
} from 'features/nodes/types/field';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import type { AnyEdge, AnyNode, ConnectorNode } from 'features/nodes/types/invocation';
import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import type {
BuilderForm,
ContainerElement,
@@ -103,7 +110,7 @@ export const getInitialWorkflow = (): Omit<NodesState, 'mode' | 'formFieldInitia
tags: '',
notes: '',
exposedFields: [],
meta: { version: '3.0.0', category: 'user' },
meta: { version: '4.0.0', category: 'user' },
form: getDefaultForm(),
nodes: [],
edges: [],
@@ -175,6 +182,33 @@ const slice = createSlice({
initialState: getInitialState(),
reducers: {
nodesChanged: (state, action: PayloadAction<NodeChange<AnyNode>[]>) => {
const removedConnectorSpliceEdges: AnyEdge[] = action.payload.flatMap((change) => {
if (change.type !== 'remove') {
return [];
}
const node = state.nodes.find((candidate) => candidate.id === change.id);
if (!isConnectorNode(node)) {
return [];
}
const resolvedSource = resolveConnectorSource(node.id, state.nodes, state.edges);
if (!resolvedSource) {
return [];
}
return getConnectorOutputEdges(node.id, state.edges)
.filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default')
.map((edge) =>
connectionToEdge({
source: resolvedSource.nodeId,
sourceHandle: resolvedSource.fieldName,
target: edge.target,
targetHandle: edge.targetHandle,
})
);
});
// TODO(psyche): The below TS issue was recently fixed upstream. Need to upgrade @xyflow/react and then we
// should be able to remove this cast.
//
@@ -206,6 +240,12 @@ const slice = createSlice({
if (edgeChanges.length > 0) {
state.edges = applyEdgeChanges<AnyEdge>(edgeChanges, state.edges);
}
if (removedConnectorSpliceEdges.length > 0) {
state.edges = applyEdgeChanges<AnyEdge>(
removedConnectorSpliceEdges.map((edge) => ({ type: 'add', item: edge })),
state.edges
);
}
}
const wereNodesRemoved = action.payload.some((change) => change.type === 'remove' || change.type === 'replace');
@@ -396,11 +436,40 @@ const slice = createSlice({
}
}
},
connectorInserted: (
state,
action: PayloadAction<{
edgeId: string;
connector: ConnectorNode;
}>
) => {
const { edgeId, connector } = action.payload;
const edge = state.edges.find((candidate) => candidate.id === edgeId);
if (!edge || edge.type !== 'default') {
return;
}
state.nodes.push({ ...SHARED_NODE_PROPERTIES, ...connector } as (typeof state.nodes)[number]);
state.edges = state.edges.filter((candidate) => candidate.id !== edgeId);
state.edges.push(
connectionToEdge({
source: edge.source,
sourceHandle: edge.sourceHandle ?? null,
target: connector.id,
targetHandle: CONNECTOR_INPUT_HANDLE,
}),
connectionToEdge({
source: connector.id,
sourceHandle: CONNECTOR_OUTPUT_HANDLE,
target: edge.target,
targetHandle: edge.targetHandle ?? null,
})
);
},
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
const { nodeId, label } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (isInvocationNode(node) || isNotesNode(node)) {
if (isInvocationNode(node) || isNotesNode(node) || isConnectorNode(node)) {
node.data.label = label;
}
},
@@ -614,6 +683,7 @@ export const {
nodeEditorReset,
nodeIsIntermediateChanged,
nodeIsOpenChanged,
connectorInserted,
nodeLabelChanged,
nodeNotesChanged,
nodesChanged,

View File

@@ -0,0 +1,126 @@
import type { AnyNode, ConnectorNode } from 'features/nodes/types/invocation';
import { describe, expect, it } from 'vitest';
import {
CONNECTOR_INPUT_HANDLE,
CONNECTOR_OUTPUT_HANDLE,
getConnectorDeletionSpliceConnections,
getConnectorInputEdge,
getConnectorOutputEdges,
resolveConnectorSource,
resolveConnectorSourceFieldType,
} from './connectorTopology';
import { add, buildEdge, buildNode, img_resize, sub, templates } from './testUtils';
const buildConnectorNode = (id: string): ConnectorNode => ({
id,
type: 'connector',
position: { x: 0, y: 0 },
data: {
id,
type: 'connector',
label: 'Connector',
isOpen: true,
},
});
describe('connectorTopology', () => {
it('resolves the effective upstream source through one connector', () => {
const source = buildNode(add);
const connector = buildConnectorNode('connector-1');
const target = buildNode(sub);
const nodes: AnyNode[] = [source, connector, target];
const edges = [
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
];
expect(resolveConnectorSource(connector.id, nodes, edges)).toEqual({
nodeId: source.id,
fieldName: 'value',
});
expect(resolveConnectorSourceFieldType(connector.id, nodes, edges, templates)).toEqual(add.outputs.value?.type);
});
it('resolves the effective upstream source through chained connectors', () => {
const source = buildNode(add);
const connectorA = buildConnectorNode('connector-a');
const connectorB = buildConnectorNode('connector-b');
const nodes: AnyNode[] = [source, connectorA, connectorB];
const edges = [
buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
];
expect(resolveConnectorSource(connectorB.id, nodes, edges)).toEqual({
nodeId: source.id,
fieldName: 'value',
});
});
it('returns no source or type for an unresolved connector chain', () => {
const connectorA = buildConnectorNode('connector-a');
const connectorB = buildConnectorNode('connector-b');
const nodes: AnyNode[] = [connectorA, connectorB];
const edges = [buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE)];
expect(resolveConnectorSource(connectorB.id, nodes, edges)).toBe(null);
expect(resolveConnectorSourceFieldType(connectorB.id, nodes, edges, templates)).toBe(null);
});
it('enumerates multiple outgoing edges for a connector', () => {
const source = buildNode(add);
const connector = buildConnectorNode('connector-1');
const targetA = buildNode(sub);
const targetB = buildNode(img_resize);
const incoming = buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE);
const outgoingA = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a');
const outgoingB = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width');
const edges = [incoming, outgoingA, outgoingB];
expect(getConnectorInputEdge(connector.id, edges)).toEqual(incoming);
expect(getConnectorOutputEdges(connector.id, edges)).toEqual([outgoingA, outgoingB]);
});
it('rejects connector deletion splice-through when any downstream target would be invalid', () => {
const source = buildNode(add);
const connector = buildConnectorNode('connector-1');
const target = buildNode(img_resize);
const nodes: AnyNode[] = [source, connector, target];
const edges = [
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'image'),
];
expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null);
});
it('builds connector deletion splice-through edges when every downstream target remains valid', () => {
const source = buildNode(add);
const connector = buildConnectorNode('connector-1');
const target = buildNode(sub);
const nodes: AnyNode[] = [source, connector, target];
const edges = [
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
];
expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toEqual([
{
source: source.id,
sourceHandle: 'value',
target: target.id,
targetHandle: 'a',
},
]);
});
it('returns no splice-through edges when a connector has downstream targets but no upstream source', () => {
const connector = buildConnectorNode('connector-1');
const target = buildNode(sub);
const nodes: AnyNode[] = [connector, target];
const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')];
expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null);
});
});

View File

@@ -0,0 +1,228 @@
import type { Templates } from 'features/nodes/store/types';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation';
export const CONNECTOR_INPUT_HANDLE = 'in';
export const CONNECTOR_OUTPUT_HANDLE = 'out';
type ResolvedConnectorSource = {
nodeId: string;
fieldName: string;
};
type SpliceConnection = {
source: string;
sourceHandle: string;
target: string;
targetHandle: string;
};
type SpliceConnectionValidator = (
connection: SpliceConnection,
nodes: AnyNode[],
edges: AnyEdge[],
templates: Templates,
ignoreEdge: AnyEdge | null,
strict?: boolean
) => string | null;
export const getConnectorInputEdge = (connectorId: string, edges: AnyEdge[]): AnyEdge | undefined =>
edges.find(
(edge) =>
edge.type === 'default' &&
edge.target === connectorId &&
edge.targetHandle === CONNECTOR_INPUT_HANDLE &&
typeof edge.sourceHandle === 'string'
);
export const getConnectorOutputEdges = (connectorId: string, edges: AnyEdge[]): AnyEdge[] =>
edges.filter(
(edge) =>
edge.type === 'default' &&
edge.source === connectorId &&
edge.sourceHandle === CONNECTOR_OUTPUT_HANDLE &&
typeof edge.targetHandle === 'string'
);
export const resolveConnectorSource = (
connectorId: string,
nodes: AnyNode[],
edges: AnyEdge[]
): ResolvedConnectorSource | null => {
const visited = new Set<string>();
const resolve = (nodeId: string): ResolvedConnectorSource | null => {
if (visited.has(nodeId)) {
return null;
}
visited.add(nodeId);
const incomingEdge = getConnectorInputEdge(nodeId, edges);
if (!incomingEdge || incomingEdge.type !== 'default') {
return null;
}
if (typeof incomingEdge.sourceHandle !== 'string') {
return null;
}
const sourceNode = nodes.find((node) => node.id === incomingEdge.source);
if (!sourceNode) {
return null;
}
if (isInvocationNode(sourceNode)) {
return { nodeId: sourceNode.id, fieldName: incomingEdge.sourceHandle };
}
if (isConnectorNode(sourceNode)) {
return resolve(sourceNode.id);
}
return null;
};
return resolve(connectorId);
};
export const resolveConnectorSourceFieldType = (
connectorId: string,
nodes: AnyNode[],
edges: AnyEdge[],
templates: Templates
): FieldType | null => {
const resolvedSource = resolveConnectorSource(connectorId, nodes, edges);
if (!resolvedSource) {
return null;
}
const sourceNode = nodes.find((node) => node.id === resolvedSource.nodeId);
if (!sourceNode || !isInvocationNode(sourceNode)) {
return null;
}
const sourceTemplate = templates[sourceNode.data.type];
return sourceTemplate?.outputs[resolvedSource.fieldName]?.type ?? null;
};
export const getConnectorDeletionSpliceConnections = (
connectorId: string,
nodes: AnyNode[],
edges: AnyEdge[],
templates: Templates,
validateConnection?: SpliceConnectionValidator
): SpliceConnection[] | null => {
const resolvedSource = resolveConnectorSource(connectorId, nodes, edges);
if (!resolvedSource) {
return null;
}
const outputEdges = getConnectorOutputEdges(connectorId, edges);
const spliceConnections = outputEdges
.filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default')
.map((edge) => ({
source: resolvedSource.nodeId,
sourceHandle: resolvedSource.fieldName,
target: edge.target,
targetHandle: edge.targetHandle,
}));
const deduped = new Set<string>();
for (const connection of spliceConnections) {
const key = `${connection.source}:${connection.sourceHandle}->${connection.target}:${connection.targetHandle}`;
if (deduped.has(key)) {
return null;
}
deduped.add(key);
}
if (!validateConnection) {
const sourceType = resolveConnectorSourceFieldType(connectorId, nodes, edges, templates);
if (!sourceType) {
return null;
}
const inputEdgeId = getConnectorInputEdge(connectorId, edges)?.id;
const outputEdgeIds = new Set(outputEdges.map((edge) => edge.id));
for (const connection of spliceConnections) {
const targetNode = nodes.find((node) => node.id === connection.target);
if (!targetNode || !isInvocationNode(targetNode)) {
return null;
}
const targetTemplate = templates[targetNode.data.type];
const targetFieldTemplate = targetTemplate?.inputs[connection.targetHandle];
if (!targetFieldTemplate) {
return null;
}
const matchesExistingDirectEdge = edges.some(
(edge) =>
edge.type === 'default' &&
edge.source === connection.source &&
edge.sourceHandle === connection.sourceHandle &&
edge.target === connection.target &&
edge.targetHandle === connection.targetHandle
);
if (matchesExistingDirectEdge) {
return null;
}
const targetConflictCount = spliceConnections.filter(
(candidate) => candidate.target === connection.target && candidate.targetHandle === connection.targetHandle
).length;
const existingTargetConflict = edges.some(
(edge) =>
edge.type === 'default' &&
edge.id !== inputEdgeId &&
!outputEdgeIds.has(edge.id) &&
edge.target === connection.target &&
edge.targetHandle === connection.targetHandle
);
if (
targetFieldTemplate.type.name !== 'CollectionItemField' &&
(targetConflictCount > 1 || existingTargetConflict)
) {
return null;
}
if (
sourceType.name !== targetFieldTemplate.type.name &&
targetFieldTemplate.type.name !== 'CollectionItemField'
) {
return null;
}
}
return spliceConnections;
}
const ignoredEdgeIds = new Set([
getConnectorInputEdge(connectorId, edges)?.id,
...outputEdges.map((edge) => edge.id),
]);
const existingEdges = edges.filter((edge) => !ignoredEdgeIds.has(edge.id));
const stagedConnections: SpliceConnection[] = [];
for (const connection of spliceConnections) {
const stagedEdges = [
...existingEdges,
...stagedConnections.map(
({ source, sourceHandle, target, targetHandle }) =>
({
id: `splice-${source}-${sourceHandle}-${target}-${targetHandle}`,
type: 'default',
source,
sourceHandle,
target,
targetHandle,
}) satisfies AnyEdge
),
];
if (validateConnection(connection, nodes, stagedEdges, templates, null, true) !== null) {
return null;
}
stagedConnections.push(connection);
}
return spliceConnections;
};

View File

@@ -1,13 +1,26 @@
import { deepClone } from 'common/util/deepClone';
import { unset } from 'es-toolkit/compat';
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
import {
getFirstValidConnection,
getSourceCandidateFields,
getTargetCandidateFields,
} from 'features/nodes/store/util/getFirstValidConnection';
import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils';
import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils';
import { describe, expect, it } from 'vitest';
const buildConnectorNode = (id: string) => ({
id,
type: 'connector' as const,
position: { x: 0, y: 0 },
data: {
id,
type: 'connector' as const,
label: 'Connector',
isOpen: true,
},
});
describe('getFirstValidConnection', () => {
it('should return null if the pending and candidate nodes are the same node', () => {
const n = buildNode(add);
@@ -120,6 +133,33 @@ describe('getFirstValidConnection', () => {
expect(r).toEqual(null);
});
});
it('should resolve connector target candidates when connecting an invocation output to a connector', () => {
const n1 = buildNode(add);
const connector = buildConnectorNode('connector-1');
expect(getFirstValidConnection(n1.id, 'value', connector.id, null, [n1, connector], [], templates, null)).toEqual({
source: n1.id,
sourceHandle: 'value',
target: connector.id,
targetHandle: CONNECTOR_INPUT_HANDLE,
});
});
it('should resolve connector source candidates when connecting a connector to a typed invocation input', () => {
const n1 = buildNode(add);
const connector = buildConnectorNode('connector-1');
const n2 = buildNode(img_resize);
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
expect(
getFirstValidConnection(connector.id, null, n2.id, 'width', [n1, connector, n2], edges, templates, null)
).toEqual({
source: connector.id,
sourceHandle: CONNECTOR_OUTPUT_HANDLE,
target: n2.id,
targetHandle: 'width',
});
});
});
describe('getTargetCandidateFields', () => {
@@ -160,6 +200,62 @@ describe('getTargetCandidateFields', () => {
const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate);
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
});
it('should return the connector input handle when the target is a connector', () => {
const n1 = buildNode(add);
const connector = buildConnectorNode('connector-1');
const r = getTargetCandidateFields(n1.id, 'value', connector.id, [n1, connector], [], templates, null);
expect(r.map((field) => field.name)).toEqual([CONNECTOR_INPUT_HANDLE]);
});
it('should advertise typed target candidates for an unresolved connector output when no downstream constraint exists', () => {
const connector = buildConnectorNode('connector-1');
const n2 = buildNode(sub);
const r = getTargetCandidateFields(
connector.id,
CONNECTOR_OUTPUT_HANDLE,
n2.id,
[connector, n2],
[],
templates,
null
);
expect(r.map((field) => field.name)).toEqual(['a', 'b']);
});
it('should only advertise compatible typed target candidates for an unresolved connector output with downstream constraints', () => {
const connector = buildConnectorNode('connector-1');
const n1 = buildNode(sub);
const n2 = buildNode(img_resize);
const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')];
const r = getTargetCandidateFields(
connector.id,
CONNECTOR_OUTPUT_HANDLE,
n2.id,
[connector, n1, n2],
edges,
templates,
null
);
expect(r.map((field) => field.name)).toEqual(['width', 'height']);
});
it('should resolve chained connector sources like the direct upstream source', () => {
const n1 = buildNode(add);
const connectorA = buildConnectorNode('connector-a');
const connectorB = buildConnectorNode('connector-b');
const n2 = buildNode(img_resize);
const edges = [
buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
];
const r = getTargetCandidateFields(
connectorB.id,
CONNECTOR_OUTPUT_HANDLE,
n2.id,
[n1, connectorA, connectorB, n2],
edges,
templates,
null
);
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
});
});
describe('getSourceCandidateFields', () => {
@@ -200,4 +296,18 @@ describe('getSourceCandidateFields', () => {
const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate);
expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]);
});
it('should return the connector output handle when the source is a connector with a typed upstream source', () => {
const n1 = buildNode(add);
const connector = buildConnectorNode('connector-1');
const n2 = buildNode(img_resize);
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
const r = getSourceCandidateFields(n2.id, 'width', connector.id, [n1, connector, n2], edges, templates, null);
expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]);
});
it('should return a target-constrained connector source candidate when the connector chain is unresolved', () => {
const connector = buildConnectorNode('connector-1');
const n2 = buildNode(img_resize);
const r = getSourceCandidateFields(n2.id, 'width', connector.id, [connector, n2], [], templates, null);
expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]);
});
});

View File

@@ -1,9 +1,15 @@
import type { Connection } from '@xyflow/react';
import { map } from 'es-toolkit/compat';
import type { Templates } from 'features/nodes/store/types';
import {
CONNECTOR_INPUT_HANDLE,
CONNECTOR_OUTPUT_HANDLE,
resolveConnectorSourceFieldType,
} from 'features/nodes/store/util/connectorTopology';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import { isConnectorNode } from 'features/nodes/types/invocation';
/**
*
@@ -91,16 +97,43 @@ export const getTargetCandidateFields = (
return [];
}
const sourceTemplate = templates[sourceNode.data.type];
if (isConnectorNode(targetNode)) {
const candidate = {
name: CONNECTOR_INPUT_HANDLE,
title: 'Connector Input',
description: '',
fieldKind: 'input',
input: 'connection',
required: false,
default: undefined,
ui_hidden: false,
type: {
name: 'AnyField',
cardinality: 'SINGLE',
batch: false,
},
} satisfies FieldInputTemplate;
const c = { source, sourceHandle, target, targetHandle: candidate.name };
return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : [];
}
const targetTemplate = templates[targetNode.data.type];
if (!sourceTemplate || !targetTemplate) {
if (!targetTemplate) {
return [];
}
const sourceField = sourceTemplate.outputs[sourceHandle];
if (!isConnectorNode(sourceNode)) {
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return [];
}
if (!sourceField) {
return [];
const sourceField = sourceTemplate.outputs[sourceHandle];
if (!sourceField) {
return [];
}
}
const targetCandidateFields = map(targetTemplate.inputs).filter((field) => {
@@ -127,15 +160,45 @@ export const getSourceCandidateFields = (
return [];
}
if (isConnectorNode(sourceNode)) {
const sourceFieldType = resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates);
const targetTemplate = !isConnectorNode(targetNode) ? templates[targetNode.data.type] : null;
const targetFieldType = targetTemplate?.inputs[targetHandle]?.type;
const candidateType = sourceFieldType ?? targetFieldType;
if (!candidateType) {
return [];
}
const candidate = {
name: CONNECTOR_OUTPUT_HANDLE,
title: 'Connector Output',
description: '',
fieldKind: 'output',
ui_hidden: false,
type: candidateType,
} satisfies FieldOutputTemplate;
const c = { source, sourceHandle: candidate.name, target, targetHandle };
return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : [];
}
const sourceTemplate = templates[sourceNode.data.type];
const targetTemplate = templates[targetNode.data.type];
if (!sourceTemplate || !targetTemplate) {
if (!sourceTemplate) {
return [];
}
const targetField = targetTemplate.inputs[targetHandle];
if (!isConnectorNode(targetNode)) {
const targetTemplate = templates[targetNode.data.type];
if (!targetTemplate) {
return [];
}
if (!targetField) {
const targetField = targetTemplate.inputs[targetHandle];
if (!targetField) {
return [];
}
} else if (targetHandle !== CONNECTOR_INPUT_HANDLE) {
return [];
}

View File

@@ -0,0 +1,23 @@
import { describe, expect, it } from 'vitest';
import { connectionToEdge } from './reactFlowUtil';
describe('connectionToEdge', () => {
it('creates a default edge with the expected id and endpoints', () => {
expect(
connectionToEdge({
source: 'source-node',
sourceHandle: 'value',
target: 'target-node',
targetHandle: 'a',
})
).toEqual({
type: 'default',
source: 'source-node',
sourceHandle: 'value',
target: 'target-node',
targetHandle: 'a',
id: 'reactflow__edge-source-nodevalue-target-nodea',
});
});
});

View File

@@ -24,6 +24,7 @@ export const connectionToEdge = (connection: Connection): AnyEdge => {
const { source, sourceHandle, target, targetHandle } = connection;
assert(source && sourceHandle && target && targetHandle, 'Invalid connection');
return {
type: 'default',
source,
sourceHandle,
target,

View File

@@ -3,6 +3,11 @@ import { set } from 'es-toolkit/compat';
import type { InvocationTemplate } from 'features/nodes/types/invocation';
import { describe, expect, it } from 'vitest';
import {
CONNECTOR_INPUT_HANDLE,
CONNECTOR_OUTPUT_HANDLE,
getConnectorDeletionSpliceConnections,
} from './connectorTopology';
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
import { validateConnection } from './validateConnection';
@@ -139,6 +144,18 @@ const integerCollectionOutputTemplate: InvocationTemplate = {
classification: 'stable',
};
const buildConnectorNode = (id: string) => ({
id,
type: 'connector' as const,
position: { x: 0, y: 0 },
data: {
id,
type: 'connector' as const,
label: 'Connector',
isOpen: true,
},
});
describe(validateConnection.name, () => {
it('should reject invalid connection to self', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
@@ -458,6 +475,214 @@ describe(validateConnection.name, () => {
expect(r).toEqual('nodes.connectionWouldCreateCycle');
});
describe('connectors', () => {
it('should accept invocation output to connector input', () => {
const n1 = buildNode(add);
const connector = buildConnectorNode('connector-1');
const r = validateConnection(
{ source: n1.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE },
[n1, connector],
[],
templates,
null
);
expect(r).toEqual(null);
});
it('should reject a second input into a connector', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const connector = buildConnectorNode('connector-1');
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
const r = validateConnection(
{ source: n2.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE },
[n1, n2, connector],
edges,
templates,
null
);
expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection');
});
it('should accept connector output to invocation input when the upstream type matches', () => {
const n1 = buildNode(add);
const connector = buildConnectorNode('connector-1');
const n2 = buildNode(sub);
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
const r = validateConnection(
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' },
[n1, connector, n2],
edges,
templates,
null
);
expect(r).toEqual(null);
});
it('should reject connector output to invocation input when the upstream type mismatches', () => {
const n1 = buildNode(add);
const connector = buildConnectorNode('connector-1');
const n2 = buildNode(img_resize);
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
const r = validateConnection(
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' },
[n1, connector, n2],
edges,
templates,
null
);
expect(r).toEqual('nodes.fieldTypesMustMatch');
});
it('should accept unresolved connector output to a typed invocation input as the first downstream constraint', () => {
const connector = buildConnectorNode('connector-1');
const n2 = buildNode(sub);
const r = validateConnection(
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' },
[connector, n2],
[],
templates,
null
);
expect(r).toEqual(null);
});
it('should reject unresolved connector output when it conflicts with an existing downstream typed constraint', () => {
const connector = buildConnectorNode('connector-1');
const n1 = buildNode(sub);
const n2 = buildNode(img_resize);
const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')];
const r = validateConnection(
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' },
[connector, n1, n2],
edges,
templates,
null
);
expect(r).toEqual('nodes.fieldTypesMustMatch');
});
it('should reject connecting an incompatible upstream source into a connector with downstream typed constraints', () => {
const source = buildNode(main_model_loader);
const connector = buildConnectorNode('connector-1');
const target = buildNode(sub);
const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')];
const r = validateConnection(
{ source: source.id, sourceHandle: 'vae', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE },
[source, connector, target],
edges,
templates,
null
);
expect(r).toEqual('nodes.fieldTypesMustMatch');
});
it('should preserve type information through chained connectors', () => {
const n1 = buildNode(add);
const connectorA = buildConnectorNode('connector-a');
const connectorB = buildConnectorNode('connector-b');
const n2 = buildNode(sub);
const edges = [
buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
];
const r = validateConnection(
{ source: connectorB.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' },
[n1, connectorA, connectorB, n2],
edges,
templates,
null
);
expect(r).toEqual(null);
});
it('should reject cycles routed through connectors', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const connector = buildConnectorNode('connector-1');
const edges = [
buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'),
];
const r = validateConnection(
{ source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' },
[n1, n2, connector],
edges,
templates,
null
);
expect(r).toEqual('nodes.connectionWouldCreateCycle');
});
it('should preserve collect item validation through connectors', () => {
const n1 = buildNode(add);
const n2 = buildNode(collect);
const n3 = buildNode(main_model_loader);
const connector = buildConnectorNode('connector-1');
const edges = [
buildEdge(n1.id, 'value', n2.id, 'item'),
buildEdge(n3.id, 'vae', connector.id, CONNECTOR_INPUT_HANDLE),
];
const r = validateConnection(
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'item' },
[n1, n2, n3, connector],
edges,
templates,
null
);
expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes');
});
it('should preserve if branch validation through connectors', () => {
const n1 = buildNode(add);
const n2 = buildNode(img_resize);
const n3 = buildNode(ifTemplate);
const connector = buildConnectorNode('connector-1');
const edges = [
buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n3.id, 'true_input'),
];
const r = validateConnection(
{ source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'false_input' },
[n1, n2, n3, connector],
edges,
{ ...templates, if: ifTemplate },
null
);
expect(r).toEqual('nodes.fieldTypesMustMatch');
});
it('should reject connector deletion splice-through when it would duplicate an existing direct edge', () => {
const n1 = buildNode(add);
const n2 = buildNode(sub);
const connector = buildConnectorNode('connector-1');
const edges = [
buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'),
buildEdge(n1.id, 'value', n2.id, 'a'),
];
expect(getConnectorDeletionSpliceConnections(connector.id, [n1, n2, connector], edges, templates)).toBe(null);
});
it('should reject connector deletion splice-through when fan-out would violate a single-input target', () => {
const n1 = buildNode(add);
const connectorA = buildConnectorNode('connector-a');
const connectorB = buildConnectorNode('connector-b');
const n2 = buildNode(sub);
const edges = [
buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'),
buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'),
];
expect(
getConnectorDeletionSpliceConnections(connectorA.id, [n1, connectorA, connectorB, n2], edges, templates)
).toBe(null);
});
});
describe('non-strict mode', () => {
it('should reject connections from self to self in non-strict mode', () => {
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };

View File

@@ -1,10 +1,17 @@
import type { Connection as NullableConnection } from '@xyflow/react';
import type { Templates } from 'features/nodes/store/types';
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
import {
CONNECTOR_INPUT_HANDLE,
CONNECTOR_OUTPUT_HANDLE,
resolveConnectorSource,
} from 'features/nodes/store/util/connectorTopology';
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
import type { FieldType } from 'features/nodes/types/field';
import type { AnyEdge, AnyNode, InvocationNode } from 'features/nodes/types/invocation';
import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation';
import type { SetNonNullable } from 'type-fest';
type Connection = SetNonNullable<NullableConnection>;
@@ -18,6 +25,12 @@ type ValidateConnectionFunc = (
strict?: boolean
) => string | null;
type EffectiveSource = {
node: InvocationNode;
handle: string;
fieldTemplate: NonNullable<Templates[string]>['outputs'][string];
};
const getEqualityPredicate =
(c: Connection) =>
(e: AnyEdge): boolean => {
@@ -41,10 +54,7 @@ const isIfInputHandle = (handle: string): handle is (typeof IF_INPUT_HANDLES)[nu
return IF_INPUT_HANDLES.includes(handle as (typeof IF_INPUT_HANDLES)[number]);
};
const isSingleCollectionPairOfSameBaseType = (
firstType: { name: string; cardinality: string; batch: boolean },
secondType: { name: string; cardinality: string; batch: boolean }
) => {
const isSingleCollectionPairOfSameBaseType = (firstType: FieldType, secondType: FieldType) => {
const isSingleToCollection =
firstType.cardinality === 'SINGLE' && secondType.cardinality === 'COLLECTION' && firstType.name === secondType.name;
const isCollectionToSingle =
@@ -52,6 +62,139 @@ const isSingleCollectionPairOfSameBaseType = (
return firstType.batch === secondType.batch && (isSingleToCollection || isCollectionToSingle);
};
const areFieldTypesCompatible = (firstType: FieldType, secondType: FieldType) =>
validateConnectionTypes(firstType, secondType) ||
validateConnectionTypes(secondType, firstType) ||
isSingleCollectionPairOfSameBaseType(firstType, secondType);
type ConnectorTerminalTargetEdge = AnyEdge & {
type: 'default';
sourceHandle: string;
targetHandle: string;
};
const getConnectorTerminalTargetEdges = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => {
const visited = new Set<string>();
const resolve = (currentConnectorId: string): ConnectorTerminalTargetEdge[] => {
if (visited.has(currentConnectorId)) {
return [];
}
visited.add(currentConnectorId);
return edges.flatMap((edge) => {
if (
edge.type !== 'default' ||
edge.source !== currentConnectorId ||
edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE ||
typeof edge.targetHandle !== 'string'
) {
return [];
}
const targetNode = nodes.find((node) => node.id === edge.target);
if (targetNode && isConnectorNode(targetNode)) {
return resolve(targetNode.id);
}
return [edge as ConnectorTerminalTargetEdge];
});
};
return resolve(connectorId);
};
const getConnectorSubgraphEdgeIds = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => {
const visited = new Set<string>();
const edgeIds = new Set<string>();
const visit = (currentConnectorId: string) => {
if (visited.has(currentConnectorId)) {
return;
}
visited.add(currentConnectorId);
edges.forEach((edge) => {
if (
edge.type !== 'default' ||
edge.source !== currentConnectorId ||
edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE ||
typeof edge.targetHandle !== 'string'
) {
return;
}
edgeIds.add(edge.id);
const targetNode = nodes.find((node) => node.id === edge.target);
if (targetNode && isConnectorNode(targetNode)) {
visit(targetNode.id);
}
});
};
visit(connectorId);
return edgeIds;
};
const getEffectiveSource = (
sourceId: string,
sourceHandle: string,
nodes: AnyNode[],
edges: AnyEdge[],
templates: Templates
): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => {
const sourceNode = nodes.find((n) => n.id === sourceId);
if (!sourceNode) {
return 'nodes.missingNode';
}
if (isConnectorNode(sourceNode)) {
if (sourceHandle !== CONNECTOR_OUTPUT_HANDLE) {
return 'nodes.missingFieldTemplate';
}
const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges);
if (!resolvedSource) {
return null;
}
return getEffectiveSource(resolvedSource.nodeId, resolvedSource.fieldName, nodes, edges, templates);
}
if (!isInvocationNode(sourceNode)) {
return 'nodes.missingInvocationTemplate';
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
return 'nodes.missingInvocationTemplate';
}
const sourceFieldTemplate = sourceTemplate.outputs[sourceHandle];
if (!sourceFieldTemplate) {
return 'nodes.missingFieldTemplate';
}
return {
node: sourceNode,
handle: sourceHandle,
fieldTemplate: sourceFieldTemplate,
};
};
const getEffectiveSourceForEdge = (
edge: AnyEdge,
nodes: AnyNode[],
edges: AnyEdge[],
templates: Templates
): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => {
if (edge.type !== 'default' || typeof edge.sourceHandle !== 'string') {
return null;
}
return getEffectiveSource(edge.source, edge.sourceHandle, nodes, edges, templates);
};
/**
* Validates a connection between two fields
* @returns A translation key for an error if the connection is invalid, otherwise null
@@ -83,18 +226,63 @@ export const validateConnection: ValidateConnectionFunc = (
return 'nodes.cannotDuplicateConnection';
}
const sourceNode = nodes.find((n) => n.id === c.source);
if (!sourceNode) {
return 'nodes.missingNode';
}
const targetNode = nodes.find((n) => n.id === c.target);
const sourceNode = nodes.find((n) => n.id === c.source);
if (!targetNode) {
return 'nodes.missingNode';
}
const sourceTemplate = templates[sourceNode.data.type];
if (!sourceTemplate) {
const effectiveSource = getEffectiveSource(c.source, c.sourceHandle, nodes, filteredEdges, templates);
if (effectiveSource === 'nodes.missingNode') {
return 'nodes.missingNode';
}
if (effectiveSource === 'nodes.missingInvocationTemplate') {
return 'nodes.missingInvocationTemplate';
}
if (effectiveSource === 'nodes.missingFieldTemplate') {
return 'nodes.missingFieldTemplate';
}
if (isConnectorNode(targetNode)) {
if (c.targetHandle !== CONNECTOR_INPUT_HANDLE) {
return 'nodes.missingFieldTemplate';
}
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
return 'nodes.inputMayOnlyHaveOneConnection';
}
if (effectiveSource) {
const connectorSubgraphEdgeIds = getConnectorSubgraphEdgeIds(targetNode.id, nodes, filteredEdges);
const stagedEdges = filteredEdges.filter((edge) => !connectorSubgraphEdgeIds.has(edge.id));
const terminalTargetEdges = getConnectorTerminalTargetEdges(targetNode.id, nodes, filteredEdges);
for (const terminalTargetEdge of terminalTargetEdges) {
const downstreamValidation = validateConnection(
{
source: c.source,
sourceHandle: c.sourceHandle,
target: terminalTargetEdge.target,
targetHandle: terminalTargetEdge.targetHandle,
},
nodes,
stagedEdges,
templates,
null,
true
);
if (downstreamValidation !== null) {
return downstreamValidation;
}
}
}
// Unresolved connector chains are allowed to terminate on another connector.
return null;
}
if (!isInvocationNode(targetNode)) {
return 'nodes.missingInvocationTemplate';
}
@@ -103,11 +291,6 @@ export const validateConnection: ValidateConnectionFunc = (
return 'nodes.missingInvocationTemplate';
}
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
if (!sourceFieldTemplate) {
return 'nodes.missingFieldTemplate';
}
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
if (!targetFieldTemplate) {
return 'nodes.missingFieldTemplate';
@@ -117,23 +300,58 @@ export const validateConnection: ValidateConnectionFunc = (
return 'nodes.cannotConnectToDirectInput';
}
if (!effectiveSource) {
if (sourceNode && isConnectorNode(sourceNode) && c.sourceHandle === CONNECTOR_OUTPUT_HANDLE) {
const existingTerminalTargetEdges = getConnectorTerminalTargetEdges(sourceNode.id, nodes, filteredEdges).filter(
(edge) => !(edge.target === c.target && edge.targetHandle === c.targetHandle)
);
for (const terminalTargetEdge of existingTerminalTargetEdges) {
const constrainedTargetNode = nodes.find((node) => node.id === terminalTargetEdge.target);
if (!constrainedTargetNode || !isInvocationNode(constrainedTargetNode)) {
return 'nodes.missingInvocationTemplate';
}
const constrainedTargetTemplate = templates[constrainedTargetNode.data.type];
if (!constrainedTargetTemplate) {
return 'nodes.missingInvocationTemplate';
}
const constrainedTargetFieldTemplate = constrainedTargetTemplate.inputs[terminalTargetEdge.targetHandle];
if (!constrainedTargetFieldTemplate) {
return 'nodes.missingFieldTemplate';
}
if (!areFieldTypesCompatible(constrainedTargetFieldTemplate.type, targetFieldTemplate.type)) {
return 'nodes.fieldTypesMustMatch';
}
}
return null;
}
return 'nodes.fieldTypesMustMatch';
}
const { node: resolvedSourceNode, handle: sourceHandle, fieldTemplate: sourceFieldTemplate } = effectiveSource;
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
// Collect nodes shouldn't mix and match field types.
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
const collectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id);
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
return 'nodes.cannotMixAndMatchCollectionItemTypes';
}
}
if (
sourceNode.data.type === 'collect' &&
c.sourceHandle === 'collection' &&
resolvedSourceNode.data.type === 'collect' &&
sourceHandle === 'collection' &&
targetNode.data.type === 'collect' &&
c.targetHandle === 'collection'
) {
// Chained collect nodes should preserve a single item type when both ends are already typed.
const sourceCollectItemType = getCollectItemType(templates, nodes, edges, sourceNode.id);
const targetCollectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
const sourceCollectItemType = getCollectItemType(templates, nodes, filteredEdges, resolvedSourceNode.id);
const targetCollectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id);
if (
sourceCollectItemType &&
targetCollectItemType &&
@@ -148,33 +366,24 @@ export const validateConnection: ValidateConnectionFunc = (
const siblingInputEdge = filteredEdges.find((e) => e.target === c.target && e.targetHandle === siblingHandle);
if (siblingInputEdge) {
if (siblingInputEdge.source === null || siblingInputEdge.source === undefined) {
const siblingEffectiveSource = getEffectiveSourceForEdge(siblingInputEdge, nodes, filteredEdges, templates);
if (siblingEffectiveSource === 'nodes.missingNode') {
return 'nodes.missingNode';
}
if (siblingInputEdge.sourceHandle === null || siblingInputEdge.sourceHandle === undefined) {
return 'nodes.missingFieldTemplate';
}
const siblingSourceNode = nodes.find((n) => n.id === siblingInputEdge.source);
if (!siblingSourceNode) {
return 'nodes.missingNode';
}
const siblingSourceTemplate = templates[siblingSourceNode.data.type];
if (!siblingSourceTemplate) {
if (siblingEffectiveSource === 'nodes.missingInvocationTemplate') {
return 'nodes.missingInvocationTemplate';
}
const siblingSourceFieldTemplate = siblingSourceTemplate.outputs[siblingInputEdge.sourceHandle];
if (!siblingSourceFieldTemplate) {
if (siblingEffectiveSource === 'nodes.missingFieldTemplate') {
return 'nodes.missingFieldTemplate';
}
if (!siblingEffectiveSource) {
return 'nodes.fieldTypesMustMatch';
}
const areIfInputTypesCompatible =
validateConnectionTypes(sourceFieldTemplate.type, siblingSourceFieldTemplate.type) ||
validateConnectionTypes(siblingSourceFieldTemplate.type, sourceFieldTemplate.type) ||
isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingSourceFieldTemplate.type);
validateConnectionTypes(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type) ||
validateConnectionTypes(siblingEffectiveSource.fieldTemplate.type, sourceFieldTemplate.type) ||
isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type);
if (!areIfInputTypesCompatible) {
return 'nodes.fieldTypesMustMatch';
@@ -189,30 +398,17 @@ export const validateConnection: ValidateConnectionFunc = (
}
}
if (sourceNode.data.type === 'if' && c.sourceHandle === 'value') {
if (resolvedSourceNode.data.type === 'if' && sourceHandle === 'value') {
const ifInputEdges = filteredEdges.filter(
(e) => e.target === sourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle)
(e) =>
e.target === resolvedSourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle)
);
const ifInputTypes = ifInputEdges.flatMap((edge) => {
if (edge.source === null || edge.source === undefined) {
const ifInputSource = getEffectiveSourceForEdge(edge, nodes, filteredEdges, templates);
if (!ifInputSource || typeof ifInputSource === 'string') {
return [];
}
if (edge.sourceHandle === null || edge.sourceHandle === undefined) {
return [];
}
const ifInputSourceNode = nodes.find((n) => n.id === edge.source);
if (!ifInputSourceNode) {
return [];
}
const ifInputSourceTemplate = templates[ifInputSourceNode.data.type];
if (!ifInputSourceTemplate) {
return [];
}
const ifInputSourceFieldTemplate = ifInputSourceTemplate.outputs[edge.sourceHandle];
if (!ifInputSourceFieldTemplate) {
return [];
}
return [ifInputSourceFieldTemplate.type];
return [ifInputSource.fieldTemplate.type];
});
if (ifInputTypes.length > 0) {

View File

@@ -43,6 +43,12 @@ export const zNotesNodeData = z.object({
isOpen: z.boolean(),
notes: z.string(),
});
export const zConnectorNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('connector'),
label: z.string(),
isOpen: z.boolean(),
});
const zCurrentImageNodeData = z.object({
id: z.string().trim().min(1),
type: z.literal('current_image'),
@@ -52,6 +58,7 @@ const zCurrentImageNodeData = z.object({
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
export type ConnectorNodeData = z.infer<typeof zConnectorNodeData>;
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
const zInvocationNodeValidationSchema = z.looseObject({
@@ -70,6 +77,15 @@ const zNotesNodeValidationSchema = z.looseObject({
const zNotesNode = z.custom<Node<NotesNodeData, 'notes'>>((val) => zNotesNodeValidationSchema.safeParse(val).success);
export type NotesNode = z.infer<typeof zNotesNode>;
const zConnectorNodeValidationSchema = z.looseObject({
type: z.literal('connector'),
data: zConnectorNodeData,
});
const zConnectorNode = z.custom<Node<ConnectorNodeData, 'connector'>>(
(val) => zConnectorNodeValidationSchema.safeParse(val).success
);
export type ConnectorNode = z.infer<typeof zConnectorNode>;
const zCurrentImageNodeValidationSchema = z.looseObject({
type: z.literal('current_image'),
data: zCurrentImageNodeData,
@@ -79,12 +95,14 @@ const zCurrentImageNode = z.custom<Node<CurrentImageNodeData, 'current_image'>>(
);
export type CurrentImageNode = z.infer<typeof zCurrentImageNode>;
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]);
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zConnectorNode, zCurrentImageNode]);
export type AnyNode = z.infer<typeof zAnyNode>;
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
Boolean(node && node.type === 'invocation');
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
export const isConnectorNode = (node?: AnyNode | null): node is ConnectorNode =>
Boolean(node && node.type === 'connector');
// #endregion
// #region NodeExecutionState

View File

@@ -3,7 +3,7 @@ import { z } from 'zod';
import type { FieldType } from './field';
import { zFieldIdentifier } from './field';
import { zInvocationNodeData, zNotesNodeData } from './invocation';
import { zConnectorNodeData, zInvocationNodeData, zNotesNodeData } from './invocation';
// #region Workflow misc
const zXYPosition = z
@@ -31,7 +31,13 @@ const zWorkflowNotesNode = z.object({
data: zNotesNodeData,
position: zXYPosition,
});
const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]);
const zWorkflowConnectorNode = z.object({
id: z.string().trim().min(1),
type: z.literal('connector'),
data: zConnectorNodeData,
position: zXYPosition,
});
const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode, zWorkflowConnectorNode]);
type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
@@ -377,7 +383,7 @@ export const zWorkflowV3 = z.object({
exposedFields: z.array(zFieldIdentifier),
meta: z.object({
category: zWorkflowCategory.default('user'),
version: z.literal('3.0.0'),
version: z.literal('4.0.0'),
}),
// Use the validated form schema!
form: zValidatedBuilderForm,

View File

@@ -0,0 +1,194 @@
import { deepClone } from 'common/util/deepClone';
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils';
import { describe, expect, it } from 'vitest';
import { buildNodesGraph } from './buildNodesGraph';
const buildConnectorNode = (id: string) => ({
id,
type: 'connector' as const,
position: { x: 0, y: 0 },
data: {
id,
type: 'connector' as const,
label: 'Connector',
isOpen: true,
},
});
const buildState = (nodes: unknown[], edges: unknown[]) =>
({
nodes: {
present: {
_version: 1,
nodes,
edges,
formFieldInitialValues: {},
id: undefined,
name: '',
author: '',
description: '',
version: '',
contact: '',
tags: '',
notes: '',
exposedFields: [],
meta: { version: '4.0.0', category: 'user' },
form: {
rootElementId: 'root',
elements: {
root: {
id: 'root',
type: 'container',
data: { layout: 'column', children: [] },
},
},
},
},
},
gallery: {
autoAddBoardId: 'none',
selection: [],
},
}) as unknown as Parameters<typeof buildNodesGraph>[0];
describe('buildNodesGraph', () => {
it('flattens a single connector to one direct execution edge', () => {
const source = buildNode(add);
const target = buildNode(sub);
const connector = buildConnectorNode('connector-1');
const state = buildState(
[source, target, connector],
[
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
]
);
const graph = buildNodesGraph(state, templates);
expect(graph.nodes).not.toHaveProperty(connector.id);
expect(graph.edges).toEqual([
{
source: { node_id: source.id, field: 'value' },
destination: { node_id: target.id, field: 'a' },
},
]);
});
it('flattens chained connectors transitively', () => {
const source = buildNode(add);
const target = buildNode(sub);
const connectorA = buildConnectorNode('connector-a');
const connectorB = buildConnectorNode('connector-b');
const state = buildState(
[source, target, connectorA, connectorB],
[
buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
]
);
const graph = buildNodesGraph(state, templates);
expect(graph.edges).toEqual([
{
source: { node_id: source.id, field: 'value' },
destination: { node_id: target.id, field: 'a' },
},
]);
});
it('fans out through a connector into multiple execution edges', () => {
const source = buildNode(add);
const targetA = buildNode(sub);
const targetB = buildNode(img_resize);
const connector = buildConnectorNode('connector-1');
const state = buildState(
[source, targetA, targetB, connector],
[
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width'),
]
);
const graph = buildNodesGraph(state, templates);
expect(graph.edges).toEqual([
{
source: { node_id: source.id, field: 'value' },
destination: { node_id: targetA.id, field: 'a' },
},
{
source: { node_id: source.id, field: 'value' },
destination: { node_id: targetB.id, field: 'width' },
},
]);
});
it('drops unresolved connector paths from the execution graph', () => {
const target = buildNode(sub);
const connector = buildConnectorNode('connector-1');
const state = buildState([target, connector], [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]);
const graph = buildNodesGraph(state, templates);
expect(graph.nodes).not.toHaveProperty(connector.id);
expect(graph.edges).toEqual([]);
});
it('deduplicates effective execution edges created by flattening', () => {
const source = buildNode(add);
const target = buildNode(sub);
const connector = buildConnectorNode('connector-1');
const state = buildState(
[source, target, connector],
[
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
buildEdge(source.id, 'value', target.id, 'a'),
]
);
const graph = buildNodesGraph(state, templates);
expect(graph.edges).toEqual([
{
source: { node_id: source.id, field: 'value' },
destination: { node_id: target.id, field: 'a' },
},
]);
});
it('still omits explicit destination input values when the flattened edge exists', () => {
const source = buildNode(add);
const target = deepClone(buildNode(sub));
const connector = buildConnectorNode('connector-1');
const inputA = target.data.inputs.a;
expect(inputA).toBeDefined();
if (!inputA) {
throw new Error('Missing input a');
}
inputA.value = 'not-an-integer' as never;
const state = buildState(
[source, target, connector],
[
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
]
);
const graph = buildNodesGraph(state, templates);
expect(graph.edges).toEqual([
{
source: { node_id: source.id, field: 'value' },
destination: { node_id: target.id, field: 'a' },
},
]);
expect(graph.nodes[target.id]).not.toHaveProperty('a');
});
});

View File

@@ -4,10 +4,11 @@ import { omit, reduce } from 'es-toolkit/compat';
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { resolveConnectorSource } from 'features/nodes/store/util/connectorTopology';
import type { BoardField } from 'features/nodes/types/common';
import type { BoardFieldInputInstance } from 'features/nodes/types/field';
import { isBoardFieldInputInstance, isBoardFieldInputTemplate } from 'features/nodes/types/field';
import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import { isConnectorNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
@@ -96,12 +97,57 @@ export const buildNodesGraph = (state: RootState, templates: Templates): Require
const filteredNodeIds = filteredNodes.map(({ id }) => id);
// skip out the "dummy" edges between collapsed nodes
const filteredEdges = edges
.filter((edge) => edge.type !== 'collapsed')
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
const flattenedEdges = edges
.filter((edge) => edge.type === 'default')
.flatMap((edge) => {
const targetNode = nodes.find((node) => node.id === edge.target);
if (!targetNode || !isInvocationNode(targetNode) || !isExecutableNode(targetNode)) {
return [];
}
const sourceNode = nodes.find((node) => node.id === edge.source);
if (!sourceNode) {
return [];
}
if (isInvocationNode(sourceNode)) {
if (!isExecutableNode(sourceNode) || !filteredNodeIds.includes(sourceNode.id)) {
return [];
}
return [edge];
}
if (isConnectorNode(sourceNode)) {
const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges);
if (!resolvedSource || !filteredNodeIds.includes(resolvedSource.nodeId)) {
return [];
}
return [
{
...edge,
id: `flattened-${resolvedSource.nodeId}-${resolvedSource.fieldName}-${edge.target}-${edge.targetHandle}`,
source: resolvedSource.nodeId,
sourceHandle: resolvedSource.fieldName,
},
];
}
return [];
})
.filter((edge, index, allEdges) => {
return (
allEdges.findIndex(
(candidate) =>
candidate.source === edge.source &&
candidate.sourceHandle === edge.sourceHandle &&
candidate.target === edge.target &&
candidate.targetHandle === edge.targetHandle
) === index
);
});
// Reduce the node editor edges into invocation graph edges
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
const parsedEdges = flattenedEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
const { source, target, sourceHandle, targetHandle } = edge;
if (!sourceHandle || !targetHandle) {

View File

@@ -0,0 +1,21 @@
import type { XYPosition } from '@xyflow/react';
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
import type { ConnectorNode } from 'features/nodes/types/invocation';
import { v4 as uuidv4 } from 'uuid';
export const buildConnectorNode = (position: XYPosition): ConnectorNode => {
const nodeId = uuidv4();
const node: ConnectorNode = {
...SHARED_NODE_PROPERTIES,
id: nodeId,
type: 'connector',
position,
data: {
id: nodeId,
type: 'connector',
isOpen: true,
label: 'Connector',
},
};
return node;
};

View File

@@ -5,7 +5,7 @@ import { parseify } from 'common/util/serialize';
import { pick } from 'es-toolkit/compat';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { zWorkflowV3 } from 'features/nodes/types/workflow';
import i18n from 'i18n';
@@ -42,6 +42,9 @@ export const buildWorkflowFast = (nodesState: NodesState): WorkflowV3 => {
if (isInvocationNode(node) && node.type) {
const { id, type, data, position } = node;
newWorkflow.nodes.push({ id, type, data, position });
} else if (isConnectorNode(node) && node.type) {
const { id, type, data, position } = node;
newWorkflow.nodes.push({ id, type, data, position });
} else if (isNotesNode(node) && node.type) {
const { id, type, data, position } = node;
newWorkflow.nodes.push({ id, type, data, position });

View File

@@ -30,7 +30,7 @@ export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): Wor
description: '',
meta: {
category: 'user',
version: '3.0.0',
version: '4.0.0',
},
notes: '',
tags: '',

View File

@@ -70,9 +70,11 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
};
const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => {
// Bump version
(workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0';
// Parsing strips out any extra properties not in the latest version
return migrateV3toV4(workflowToMigrate as unknown as WorkflowV3);
};
const migrateV3toV4 = (workflowToMigrate: WorkflowV3): WorkflowV3 => {
workflowToMigrate.meta.version = '4.0.0';
return zWorkflowV3.parse(workflowToMigrate);
};
@@ -100,6 +102,10 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
workflow = migrateV2toV3(v2);
}
if (get(workflow, 'meta.version') === '3.0.0') {
workflow = migrateV3toV4(workflow as WorkflowV3);
}
// We should now have a V3 workflow
const migratedWorkflow = zWorkflowV3.parse(workflow);

View File

@@ -1,4 +1,5 @@
import { get } from 'es-toolkit/compat';
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils';
import type { WorkflowV3 } from 'features/nodes/types/workflow';
import { getDefaultForm } from 'features/nodes/types/workflow';
@@ -7,6 +8,17 @@ import { describe, expect, it } from 'vitest';
//TODO(psyche): Test workflow validation for form builder fields
describe('validateWorkflow', () => {
const buildConnectorNode = (id: string) => ({
id,
type: 'connector' as const,
position: { x: 0, y: 0 },
data: {
id,
type: 'connector' as const,
label: 'Connector',
isOpen: true,
},
});
const getWorkflow = (): WorkflowV3 => ({
name: '',
author: '',
@@ -17,7 +29,7 @@ describe('validateWorkflow', () => {
notes: '',
exposedFields: [],
form: getDefaultForm(),
meta: { version: '3.0.0', category: 'user' },
meta: { version: '4.0.0', category: 'user' },
nodes: [
{
id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad',
@@ -104,6 +116,7 @@ describe('validateWorkflow', () => {
});
expect(validationResult.warnings.length).toBe(1);
expect(get(validationResult, 'workflow.nodes[1].data.inputs.image.value')).toBeUndefined();
expect(validationResult.workflow.meta.version).toBe('4.0.0');
});
it('should reset boards that are inaccessible', async () => {
const validationResult = await validateWorkflow({
@@ -127,4 +140,138 @@ describe('validateWorkflow', () => {
expect(validationResult.warnings.length).toBe(1);
expect(get(validationResult, 'workflow.nodes[0].data.inputs.model.value')).toBeUndefined();
});
it('should delete malformed connector edges with invalid handles', async () => {
const workflow = getWorkflow();
workflow.nodes.push(buildConnectorNode('connector-1'));
workflow.edges.push({
id: 'e1',
type: 'default',
source: workflow.nodes[0]!.id,
sourceHandle: 'vae',
target: 'connector-1',
targetHandle: 'wrong',
});
const validationResult = await validateWorkflow({
workflow,
templates: { img_resize, main_model_loader },
checkImageAccess: resolveTrue,
checkBoardAccess: resolveTrue,
checkModelAccess: resolveTrue,
});
expect(validationResult.workflow.edges).toEqual([]);
expect(validationResult.warnings.length).toBe(1);
});
it('should delete connector edges with missing endpoints', async () => {
const workflow = getWorkflow();
workflow.nodes.push(buildConnectorNode('connector-1'));
workflow.edges.push({
id: 'e1',
type: 'default',
source: 'missing-node',
sourceHandle: 'value',
target: 'connector-1',
targetHandle: CONNECTOR_INPUT_HANDLE,
});
const validationResult = await validateWorkflow({
workflow,
templates: { img_resize, main_model_loader },
checkImageAccess: resolveTrue,
checkBoardAccess: resolveTrue,
checkModelAccess: resolveTrue,
});
expect(validationResult.workflow.edges).toEqual([]);
expect(validationResult.warnings.length).toBe(1);
});
it('should repair invalid multi-input connector state predictably by keeping the first valid input edge', async () => {
const workflow = getWorkflow();
const loader2 = structuredClone(workflow.nodes[0]!);
loader2.id = 'second-loader';
loader2.data.id = 'second-loader';
const connector = buildConnectorNode('connector-1');
workflow.nodes.push(loader2, connector);
workflow.edges.push({
id: 'e1',
type: 'default',
source: workflow.nodes[0]!.id,
sourceHandle: 'vae',
target: connector.id,
targetHandle: CONNECTOR_INPUT_HANDLE,
});
workflow.edges.push({
id: 'e2',
type: 'default',
source: loader2.id,
sourceHandle: 'vae',
target: connector.id,
targetHandle: CONNECTOR_INPUT_HANDLE,
});
const validationResult = await validateWorkflow({
workflow,
templates: { img_resize, main_model_loader },
checkImageAccess: resolveTrue,
checkBoardAccess: resolveTrue,
checkModelAccess: resolveTrue,
});
expect(validationResult.workflow.edges).toEqual([
{
id: 'e1',
type: 'default',
source: workflow.nodes[0]!.id,
sourceHandle: 'vae',
target: connector.id,
targetHandle: CONNECTOR_INPUT_HANDLE,
},
]);
expect(validationResult.warnings.length).toBe(1);
});
it('should retain isolated connectors during workflow validation', async () => {
const workflow = getWorkflow();
workflow.nodes.push(buildConnectorNode('connector-1'));
const validationResult = await validateWorkflow({
workflow,
templates: { img_resize, main_model_loader },
checkImageAccess: resolveTrue,
checkBoardAccess: resolveTrue,
checkModelAccess: resolveTrue,
});
expect(validationResult.workflow.nodes.find((node) => node.id === 'connector-1')).toBeDefined();
expect(validationResult.warnings).toEqual([]);
});
it('should retain unresolved connector output edges that establish downstream constraints in the editor', async () => {
const workflow = getWorkflow();
workflow.nodes.push(buildConnectorNode('connector-1'));
const unresolvedEdge = {
id: 'e1',
type: 'default' as const,
source: 'connector-1',
sourceHandle: CONNECTOR_OUTPUT_HANDLE,
target: workflow.nodes[1]!.id,
targetHandle: 'image',
};
workflow.edges.push(unresolvedEdge);
const validationResult = await validateWorkflow({
workflow,
templates: { img_resize, main_model_loader },
checkImageAccess: resolveTrue,
checkBoardAccess: resolveTrue,
checkModelAccess: resolveTrue,
});
expect(validationResult.workflow.edges).toEqual([unresolvedEdge]);
expect(validationResult.warnings).toEqual([]);
});
});

View File

@@ -1,6 +1,7 @@
import { parseify } from 'common/util/serialize';
import { addElement, getIsFormEmpty } from 'features/nodes/components/sidePanel/builder/form-manipulation';
import type { Templates } from 'features/nodes/store/types';
import { validateConnection } from 'features/nodes/store/util/validateConnection';
import {
isBoardFieldInputInstance,
isImageFieldCollectionInputInstance,
@@ -149,8 +150,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
}
}
// Stash invalid edges here to be deleted later
const edgesToDelete = new Set<string>();
const validEdges = [];
for (const edge of edges) {
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
@@ -169,7 +169,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
);
}
if (!sourceTemplate) {
if (sourceNode?.type === 'invocation' && !sourceTemplate) {
// The edge's source/output node template does not exist
issues.push(
t('nodes.missingTemplate', {
@@ -198,7 +198,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
);
}
if (!targetTemplate) {
if (targetNode?.type === 'invocation' && !targetTemplate) {
// The edge's target/input node template does not exist
issues.push(
t('nodes.missingTemplate', {
@@ -218,8 +218,14 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
);
}
if (!issues.length && edge.type === 'default') {
const connectionError = validateConnection(edge, nodes, validEdges, templates, null, true);
if (connectionError) {
issues.push(connectionError);
}
}
if (issues.length) {
edgesToDelete.add(edge.id);
const source = edge.type === 'default' ? `${edge.source}.${edge.sourceHandle}` : edge.source;
const target = edge.type === 'default' ? `${edge.source}.${edge.targetHandle}` : edge.target;
warnings.push({
@@ -227,11 +233,13 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
issues,
data: edge,
});
} else {
validEdges.push(edge);
}
}
// Remove invalid edges
_workflow.edges = edges.filter(({ id }) => !edgesToDelete.has(id));
_workflow.edges = validEdges;
// Migrated exposed fields to form elements if they exist and the form does not
// Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!