mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Merge branch 'main' into external-models
This commit is contained in:
@@ -1449,6 +1449,8 @@
|
||||
"notes": "Notes",
|
||||
"description": "Description",
|
||||
"notesDescription": "Add notes about your workflow",
|
||||
"addConnector": "Add Connector",
|
||||
"deleteConnector": "Delete Connector",
|
||||
"problemSettingTitle": "Problem Setting Title",
|
||||
"resetToDefaultValue": "Reset to default value",
|
||||
"reloadNodeTemplates": "Reload Node Templates",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { Menu, MenuButton, MenuItem, MenuList, Portal, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import type {
|
||||
EdgeChange,
|
||||
@@ -32,7 +32,9 @@ import {
|
||||
$edgePendingUpdate,
|
||||
$lastEdgeUpdateMouseEvent,
|
||||
$pendingConnection,
|
||||
$templates,
|
||||
$viewport,
|
||||
connectorInserted,
|
||||
edgesChanged,
|
||||
nodesChanged,
|
||||
redo,
|
||||
@@ -46,18 +48,24 @@ import {
|
||||
selectNodes,
|
||||
selectNodesSlice,
|
||||
} from 'features/nodes/store/selectors';
|
||||
import { getConnectorDeletionSpliceConnections } from 'features/nodes/store/util/connectorTopology';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode';
|
||||
import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData';
|
||||
import type { CSSProperties, MouseEvent, RefObject } from 'react';
|
||||
import { memo, useCallback, useMemo, useRef } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react';
|
||||
import { useHotkeys } from 'react-hotkeys-hook';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlugsConnectedBold, PiTrashBold } from 'react-icons/pi';
|
||||
|
||||
import CustomConnectionLine from './connectionLines/CustomConnectionLine';
|
||||
import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge';
|
||||
import InvocationDefaultEdge from './edges/InvocationDefaultEdge';
|
||||
import ConnectorNode from './nodes/Connector/ConnectorNode';
|
||||
import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode';
|
||||
import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper';
|
||||
import NotesNode from './nodes/Notes/NotesNode';
|
||||
@@ -70,6 +78,7 @@ const edgeTypes = {
|
||||
|
||||
const nodeTypes = {
|
||||
invocation: InvocationNodeWrapper,
|
||||
connector: ConnectorNode,
|
||||
current_image: CurrentImageNode,
|
||||
notes: NotesNode,
|
||||
} as const;
|
||||
@@ -81,17 +90,69 @@ const snapGrid: [number, number] = [25, 25];
|
||||
|
||||
const selectCancelConnection = (state: ReactFlowState) => state.cancelConnection;
|
||||
|
||||
type WorkflowContextMenuState =
|
||||
| {
|
||||
kind: 'pane';
|
||||
clientX: number;
|
||||
clientY: number;
|
||||
pageX: number;
|
||||
pageY: number;
|
||||
}
|
||||
| {
|
||||
kind: 'connector';
|
||||
connectorId: string;
|
||||
pageX: number;
|
||||
pageY: number;
|
||||
}
|
||||
| null;
|
||||
|
||||
const getWorkflowContextMenuState = (
|
||||
event: globalThis.MouseEvent,
|
||||
flowWrapper: HTMLDivElement | null
|
||||
): WorkflowContextMenuState => {
|
||||
if (event.shiftKey || !(event.target instanceof Element) || !flowWrapper?.contains(event.target)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const connectorId = event.target.closest<HTMLElement>('[data-connector-node-id]')?.dataset.connectorNodeId;
|
||||
if (connectorId) {
|
||||
return {
|
||||
kind: 'connector',
|
||||
connectorId,
|
||||
pageX: event.pageX,
|
||||
pageY: event.pageY,
|
||||
};
|
||||
}
|
||||
|
||||
const paneTarget = event.target.closest('.react-flow__pane');
|
||||
if (paneTarget && flowWrapper.contains(paneTarget)) {
|
||||
return {
|
||||
kind: 'pane',
|
||||
clientX: event.clientX,
|
||||
clientY: event.clientY,
|
||||
pageX: event.pageX,
|
||||
pageY: event.pageY,
|
||||
};
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
export const Flow = memo(() => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
const nodes = useAppSelector(selectNodes);
|
||||
const edges = useAppSelector(selectEdges);
|
||||
const templates = useStore($templates);
|
||||
const viewport = useStore($viewport);
|
||||
const shouldSnapToGrid = useAppSelector(selectShouldSnapToGrid);
|
||||
const selectionMode = useAppSelector(selectSelectionMode);
|
||||
const { onConnectStart, onConnect, onConnectEnd } = useConnection();
|
||||
const flowWrapper = useRef<HTMLDivElement>(null);
|
||||
const pendingNodeInternalsUpdateRef = useRef<string[] | null>(null);
|
||||
const isValidConnection = useIsValidConnection();
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const [contextMenuState, setContextMenuState] = useState<WorkflowContextMenuState>(null);
|
||||
|
||||
useFocusRegion('workflows', flowWrapper);
|
||||
|
||||
@@ -127,9 +188,16 @@ export const Flow = memo(() => {
|
||||
}, []);
|
||||
|
||||
const { onCloseGlobal } = useGlobalMenuClose();
|
||||
const handlePaneClick = useCallback(() => {
|
||||
onCloseGlobal();
|
||||
}, [onCloseGlobal]);
|
||||
const handlePaneClick: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onPaneClick']> = useCallback(
|
||||
(event) => {
|
||||
if ('button' in event && event.button !== 0) {
|
||||
return;
|
||||
}
|
||||
onCloseGlobal();
|
||||
setContextMenuState(null);
|
||||
},
|
||||
[onCloseGlobal]
|
||||
);
|
||||
|
||||
const onInit: OnInit<AnyNode, AnyEdge> = useCallback((flow) => {
|
||||
$flow.set(flow);
|
||||
@@ -147,6 +215,149 @@ export const Flow = memo(() => {
|
||||
}
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const pendingNodeIds = pendingNodeInternalsUpdateRef.current;
|
||||
if (!pendingNodeIds) {
|
||||
return;
|
||||
}
|
||||
|
||||
pendingNodeInternalsUpdateRef.current = null;
|
||||
|
||||
const frameId = requestAnimationFrame(() => {
|
||||
updateNodeInternals([...new Set(pendingNodeIds)]);
|
||||
});
|
||||
|
||||
return () => {
|
||||
cancelAnimationFrame(frameId);
|
||||
};
|
||||
}, [edges, nodes, updateNodeInternals]);
|
||||
|
||||
const addConnectorAtPaneMenuPosition = useCallback(() => {
|
||||
if (contextMenuState?.kind !== 'pane') {
|
||||
return;
|
||||
}
|
||||
const flow = $flow.get();
|
||||
if (!flow) {
|
||||
return;
|
||||
}
|
||||
const connector = buildConnectorNode(
|
||||
flow.screenToFlowPosition({
|
||||
x: contextMenuState.clientX,
|
||||
y: contextMenuState.clientY,
|
||||
})
|
||||
);
|
||||
dispatch(nodesChanged([{ type: 'add', item: connector }]));
|
||||
setContextMenuState(null);
|
||||
}, [contextMenuState, dispatch]);
|
||||
|
||||
const connectorSpliceConnections = useMemo(
|
||||
() =>
|
||||
contextMenuState?.kind === 'connector'
|
||||
? getConnectorDeletionSpliceConnections(
|
||||
contextMenuState.connectorId,
|
||||
nodes,
|
||||
edges,
|
||||
templates,
|
||||
validateConnection
|
||||
)
|
||||
: null,
|
||||
[contextMenuState, edges, nodes, templates]
|
||||
);
|
||||
|
||||
const deleteConnectorFromContextMenu = useCallback(() => {
|
||||
if (contextMenuState?.kind !== 'connector' || !connectorSpliceConnections) {
|
||||
return;
|
||||
}
|
||||
const connectorEdgeRemovals: EdgeChange<AnyEdge>[] = edges
|
||||
.filter((edge) => edge.source === contextMenuState.connectorId || edge.target === contextMenuState.connectorId)
|
||||
.map((edge) => ({ type: 'remove', id: edge.id }));
|
||||
const spliceEdgeAdditions: EdgeChange<AnyEdge>[] = connectorSpliceConnections.map((connection) => ({
|
||||
type: 'add',
|
||||
item: connectionToEdge(connection),
|
||||
}));
|
||||
|
||||
pendingNodeInternalsUpdateRef.current = [
|
||||
contextMenuState.connectorId,
|
||||
...connectorSpliceConnections.flatMap((connection) => [connection.source, connection.target]),
|
||||
];
|
||||
dispatch(edgesChanged([...connectorEdgeRemovals, ...spliceEdgeAdditions]));
|
||||
dispatch(nodesChanged([{ type: 'remove', id: contextMenuState.connectorId }]));
|
||||
setContextMenuState(null);
|
||||
}, [connectorSpliceConnections, contextMenuState, dispatch, edges]);
|
||||
|
||||
useEffect(() => {
|
||||
const onWindowContextMenu = (event: globalThis.MouseEvent) => {
|
||||
const nextContextMenuState = getWorkflowContextMenuState(event, flowWrapper.current);
|
||||
if (!nextContextMenuState) {
|
||||
return;
|
||||
}
|
||||
|
||||
event.preventDefault();
|
||||
event.stopPropagation();
|
||||
setContextMenuState(nextContextMenuState);
|
||||
};
|
||||
|
||||
window.addEventListener('contextmenu', onWindowContextMenu, { capture: true });
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('contextmenu', onWindowContextMenu, { capture: true });
|
||||
};
|
||||
}, []);
|
||||
|
||||
const renderContextMenu = useCallback(() => {
|
||||
if (contextMenuState?.kind === 'pane') {
|
||||
return (
|
||||
<MenuList visibility="visible">
|
||||
<MenuItem icon={<PiPlugsConnectedBold />} onClick={addConnectorAtPaneMenuPosition}>
|
||||
{t('nodes.addConnector')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
);
|
||||
}
|
||||
|
||||
if (contextMenuState?.kind === 'connector') {
|
||||
return (
|
||||
<MenuList visibility="visible">
|
||||
<MenuItem
|
||||
icon={<PiTrashBold />}
|
||||
onClick={deleteConnectorFromContextMenu}
|
||||
isDisabled={!connectorSpliceConnections}
|
||||
isDestructive
|
||||
>
|
||||
{t('nodes.deleteConnector')}
|
||||
</MenuItem>
|
||||
</MenuList>
|
||||
);
|
||||
}
|
||||
|
||||
return <MenuList visibility="visible" />;
|
||||
}, [addConnectorAtPaneMenuPosition, connectorSpliceConnections, contextMenuState, deleteConnectorFromContextMenu, t]);
|
||||
|
||||
const closeContextMenu = useCallback(() => {
|
||||
setContextMenuState(null);
|
||||
}, []);
|
||||
|
||||
const onEdgeDoubleClick = useCallback<NonNullable<ReactFlowProps['onEdgeDoubleClick']>>(
|
||||
(event, edge) => {
|
||||
if (edge.type !== 'default' || edge.hidden) {
|
||||
return;
|
||||
}
|
||||
const flow = $flow.get();
|
||||
if (!flow) {
|
||||
return;
|
||||
}
|
||||
const connector = buildConnectorNode(
|
||||
flow.screenToFlowPosition({
|
||||
x: event.clientX,
|
||||
y: event.clientY,
|
||||
})
|
||||
);
|
||||
dispatch(connectorInserted({ edgeId: edge.id, connector }));
|
||||
updateNodeInternals([edge.source, edge.target, connector.id]);
|
||||
},
|
||||
[dispatch, updateNodeInternals]
|
||||
);
|
||||
|
||||
// #region Updatable Edges
|
||||
|
||||
/**
|
||||
@@ -209,16 +420,130 @@ export const Flow = memo(() => {
|
||||
|
||||
// #endregion
|
||||
|
||||
const renderedNodes = useMemo(() => nodes, [nodes]);
|
||||
|
||||
const renderedEdges = useMemo(() => edges, [edges]);
|
||||
const contextMenuPosition = contextMenuState ? { x: contextMenuState.pageX, y: contextMenuState.pageY } : null;
|
||||
const contextMenuKey = contextMenuPosition ? `${contextMenuPosition.x}-${contextMenuPosition.y}` : 'closed';
|
||||
|
||||
return (
|
||||
<>
|
||||
<FlowSurface
|
||||
flowWrapper={flowWrapper}
|
||||
viewport={viewport}
|
||||
renderedNodes={renderedNodes}
|
||||
renderedEdges={renderedEdges}
|
||||
onInit={onInit}
|
||||
onMouseMove={onMouseMove}
|
||||
onNodesChange={onNodesChange}
|
||||
onEdgesChange={onEdgesChange}
|
||||
onReconnect={onReconnect}
|
||||
onReconnectStart={onReconnectStart}
|
||||
onReconnectEnd={onReconnectEnd}
|
||||
onConnectStart={onConnectStart}
|
||||
onConnect={onConnect}
|
||||
onConnectEnd={onConnectEnd}
|
||||
handleMoveEnd={handleMoveEnd}
|
||||
onEdgeDoubleClick={onEdgeDoubleClick}
|
||||
isValidConnection={isValidConnection}
|
||||
shouldSnapToGrid={shouldSnapToGrid}
|
||||
flowStyles={flowStyles}
|
||||
handlePaneClick={handlePaneClick}
|
||||
selectionMode={selectionMode}
|
||||
/>
|
||||
<Portal>
|
||||
<Menu
|
||||
key={contextMenuKey}
|
||||
isOpen={contextMenuPosition !== null}
|
||||
onClose={closeContextMenu}
|
||||
placement="auto-start"
|
||||
gutter={0}
|
||||
>
|
||||
<MenuButton
|
||||
aria-hidden
|
||||
position="absolute"
|
||||
left={contextMenuPosition?.x ?? -9999}
|
||||
top={contextMenuPosition?.y ?? -9999}
|
||||
w={1}
|
||||
h={1}
|
||||
pointerEvents="none"
|
||||
bg="transparent"
|
||||
/>
|
||||
{renderContextMenu()}
|
||||
</Menu>
|
||||
</Portal>
|
||||
<HotkeyIsolator flowWrapper={flowWrapper} />
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
Flow.displayName = 'Flow';
|
||||
|
||||
type FlowSurfaceProps = {
|
||||
flowWrapper: { current: HTMLDivElement | null };
|
||||
viewport: ReactFlowProps<AnyNode, AnyEdge>['defaultViewport'];
|
||||
renderedNodes: AnyNode[];
|
||||
renderedEdges: AnyEdge[];
|
||||
onInit: OnInit<AnyNode, AnyEdge>;
|
||||
onMouseMove: (event: MouseEvent<HTMLDivElement>) => void;
|
||||
onNodesChange: OnNodesChange<AnyNode>;
|
||||
onEdgesChange: OnEdgesChange<AnyEdge>;
|
||||
onReconnect: OnReconnect;
|
||||
onReconnectStart: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onReconnectStart']>;
|
||||
onReconnectEnd: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onReconnectEnd']>;
|
||||
onConnectStart: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onConnectStart']>;
|
||||
onConnect: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onConnect']>;
|
||||
onConnectEnd: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onConnectEnd']>;
|
||||
handleMoveEnd: OnMoveEnd;
|
||||
onEdgeDoubleClick: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onEdgeDoubleClick']>;
|
||||
isValidConnection: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['isValidConnection']>;
|
||||
shouldSnapToGrid: boolean;
|
||||
flowStyles: CSSProperties;
|
||||
handlePaneClick: NonNullable<ReactFlowProps<AnyNode, AnyEdge>['onPaneClick']>;
|
||||
selectionMode: ReturnType<typeof selectSelectionMode>;
|
||||
};
|
||||
|
||||
const FlowSurface = memo((props: FlowSurfaceProps) => {
|
||||
const {
|
||||
flowWrapper,
|
||||
viewport,
|
||||
renderedNodes,
|
||||
renderedEdges,
|
||||
onInit,
|
||||
onMouseMove,
|
||||
onNodesChange,
|
||||
onEdgesChange,
|
||||
onReconnect,
|
||||
onReconnectStart,
|
||||
onReconnectEnd,
|
||||
onConnectStart,
|
||||
onConnect,
|
||||
onConnectEnd,
|
||||
handleMoveEnd,
|
||||
onEdgeDoubleClick,
|
||||
isValidConnection,
|
||||
shouldSnapToGrid,
|
||||
flowStyles,
|
||||
handlePaneClick,
|
||||
selectionMode,
|
||||
} = props;
|
||||
|
||||
const setFlowWrapperElement = useCallback(
|
||||
(el: HTMLDivElement | null) => {
|
||||
flowWrapper.current = el;
|
||||
},
|
||||
[flowWrapper]
|
||||
);
|
||||
|
||||
return (
|
||||
<div ref={setFlowWrapperElement} style={{ width: '100%', height: '100%' }}>
|
||||
<ReactFlow<AnyNode, AnyEdge>
|
||||
id="workflow-editor"
|
||||
ref={flowWrapper}
|
||||
defaultViewport={viewport}
|
||||
nodeTypes={nodeTypes}
|
||||
edgeTypes={edgeTypes}
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
nodes={renderedNodes}
|
||||
edges={renderedEdges}
|
||||
onInit={onInit}
|
||||
onMouseMove={onMouseMove}
|
||||
onNodesChange={onNodesChange}
|
||||
@@ -230,6 +555,7 @@ export const Flow = memo(() => {
|
||||
onConnect={onConnect}
|
||||
onConnectEnd={onConnectEnd}
|
||||
onMoveEnd={handleMoveEnd}
|
||||
onEdgeDoubleClick={onEdgeDoubleClick}
|
||||
connectionLineComponent={CustomConnectionLine}
|
||||
isValidConnection={isValidConnection}
|
||||
minZoom={0.1}
|
||||
@@ -249,12 +575,11 @@ export const Flow = memo(() => {
|
||||
>
|
||||
<Background gap={snapGrid} offset={snapGrid} />
|
||||
</ReactFlow>
|
||||
<HotkeyIsolator flowWrapper={flowWrapper} />
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
});
|
||||
|
||||
Flow.displayName = 'Flow';
|
||||
FlowSurface.displayName = 'FlowSurface';
|
||||
|
||||
const HotkeyIsolator = memo(({ flowWrapper }: { flowWrapper: RefObject<HTMLDivElement> }) => {
|
||||
const mayUndo = useAppSelector(selectMayUndo);
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar';
|
||||
import { selectNodes } from 'features/nodes/store/selectors';
|
||||
import { selectEdges, selectNodes } from 'features/nodes/store/selectors';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { resolveConnectorSource, resolveConnectorSourceFieldType } from 'features/nodes/store/util/connectorTopology';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isConnectorNode } from 'features/nodes/types/invocation';
|
||||
|
||||
import { getFieldColor } from './getEdgeColor';
|
||||
|
||||
@@ -22,23 +23,20 @@ export const buildSelectEdgeColor = (
|
||||
target: string,
|
||||
targetHandleId: string | null | undefined
|
||||
) =>
|
||||
createSelector(selectNodes, selectWorkflowSettingsSlice, (nodes, workflowSettings): string => {
|
||||
createSelector(selectNodes, selectEdges, selectWorkflowSettingsSlice, (nodes, edges, workflowSettings): string => {
|
||||
const { shouldColorEdges } = workflowSettings;
|
||||
if (!shouldColorEdges) {
|
||||
return colorTokenToCssVar('base.500');
|
||||
}
|
||||
const sourceNode = nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.find((node) => node.id === target);
|
||||
|
||||
if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) {
|
||||
if (!sourceNode || !sourceHandleId || !targetHandleId) {
|
||||
return colorTokenToCssVar('base.500');
|
||||
}
|
||||
|
||||
const sourceNodeTemplate = templates[sourceNode.data.type];
|
||||
|
||||
const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode);
|
||||
const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId];
|
||||
const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined;
|
||||
const sourceType = isConnectorNode(sourceNode)
|
||||
? resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates)
|
||||
: templates[sourceNode.data.type]?.outputs[sourceHandleId]?.type;
|
||||
|
||||
return sourceType ? getFieldColor(sourceType) : colorTokenToCssVar('base.500');
|
||||
});
|
||||
@@ -50,7 +48,7 @@ export const buildSelectEdgeLabel = (
|
||||
target: string,
|
||||
targetHandleId: string | null | undefined
|
||||
) =>
|
||||
createSelector(selectNodes, (nodes): string | null => {
|
||||
createSelector(selectNodes, selectEdges, (nodes, edges): string | null => {
|
||||
const sourceNode = nodes.find((node) => node.id === source);
|
||||
const targetNode = nodes.find((node) => node.id === target);
|
||||
|
||||
@@ -58,8 +56,12 @@ export const buildSelectEdgeLabel = (
|
||||
return null;
|
||||
}
|
||||
|
||||
const sourceNodeTemplate = templates[sourceNode.data.type];
|
||||
const resolvedSource = isConnectorNode(sourceNode) ? resolveConnectorSource(sourceNode.id, nodes, edges) : null;
|
||||
const sourceTemplate =
|
||||
resolvedSource !== null
|
||||
? templates[nodes.find((node) => node.id === resolvedSource.nodeId)?.data.type ?? '']
|
||||
: templates[sourceNode.data.type];
|
||||
const targetNodeTemplate = templates[targetNode.data.type];
|
||||
|
||||
return `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
|
||||
return `${sourceTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`;
|
||||
});
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Icon } from '@invoke-ai/ui-library';
|
||||
import type { Node, NodeProps } from '@xyflow/react';
|
||||
import { Handle, Position } from '@xyflow/react';
|
||||
import NonInvocationNodeWrapper from 'features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper';
|
||||
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
|
||||
import { NO_DRAG_CLASS } from 'features/nodes/types/constants';
|
||||
import type { ConnectorNodeData } from 'features/nodes/types/invocation';
|
||||
import type { CSSProperties } from 'react';
|
||||
import { memo } from 'react';
|
||||
import { PiDotOutlineFill } from 'react-icons/pi';
|
||||
|
||||
const handleVisualSx = {
|
||||
w: 3,
|
||||
h: 3,
|
||||
borderRadius: 'full',
|
||||
borderWidth: 2,
|
||||
borderColor: 'base.900',
|
||||
bg: 'base.100',
|
||||
pointerEvents: 'none',
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
const handleStyles = {
|
||||
position: 'absolute',
|
||||
width: '1rem',
|
||||
height: '1rem',
|
||||
display: 'flex',
|
||||
alignItems: 'center',
|
||||
justifyContent: 'center',
|
||||
top: '50%',
|
||||
transform: 'translateY(-50%)',
|
||||
zIndex: 1,
|
||||
background: 'none',
|
||||
border: 'none',
|
||||
} satisfies CSSProperties;
|
||||
|
||||
const inputHandleStyles = {
|
||||
...handleStyles,
|
||||
insetInlineStart: 0,
|
||||
justifyContent: 'flex-start',
|
||||
} satisfies CSSProperties;
|
||||
|
||||
const outputHandleStyles = {
|
||||
...handleStyles,
|
||||
insetInlineEnd: 0,
|
||||
justifyContent: 'flex-end',
|
||||
} satisfies CSSProperties;
|
||||
|
||||
const ConnectorNode = ({ id, selected }: NodeProps<Node<ConnectorNodeData>>) => {
|
||||
return (
|
||||
<NonInvocationNodeWrapper nodeId={id} selected={selected} width={25} borderRadius="full" withChrome={false}>
|
||||
<Box
|
||||
data-connector-node-context-menu="true"
|
||||
data-connector-node-id={id}
|
||||
position="relative"
|
||||
w={25}
|
||||
h={25}
|
||||
display="flex"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
>
|
||||
<Handle
|
||||
className={NO_DRAG_CLASS}
|
||||
type="target"
|
||||
id={CONNECTOR_INPUT_HANDLE}
|
||||
position={Position.Left}
|
||||
style={inputHandleStyles}
|
||||
>
|
||||
<Box sx={handleVisualSx} />
|
||||
</Handle>
|
||||
<Box
|
||||
w={13}
|
||||
h={13}
|
||||
display="flex"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
borderRadius="full"
|
||||
bg={selected ? 'base.650' : 'base.700'}
|
||||
boxShadow={selected ? '0 0 0 2px var(--invoke-colors-blue-300)' : '0 0 0 1px var(--invoke-colors-base-500)'}
|
||||
>
|
||||
<Icon as={PiDotOutlineFill} boxSize={5} color={selected ? 'base.50' : 'base.100'} />
|
||||
</Box>
|
||||
<Handle
|
||||
className={NO_DRAG_CLASS}
|
||||
type="source"
|
||||
id={CONNECTOR_OUTPUT_HANDLE}
|
||||
position={Position.Right}
|
||||
style={outputHandleStyles}
|
||||
>
|
||||
<Box sx={handleVisualSx} />
|
||||
</Handle>
|
||||
</Box>
|
||||
</NonInvocationNodeWrapper>
|
||||
);
|
||||
};
|
||||
|
||||
export default memo(ConnectorNode);
|
||||
@@ -16,10 +16,12 @@ type NonInvocationNodeWrapperProps = PropsWithChildren & {
|
||||
nodeId: string;
|
||||
selected: boolean;
|
||||
width?: ChakraProps['w'];
|
||||
borderRadius?: ChakraProps['borderRadius'];
|
||||
withChrome?: boolean;
|
||||
};
|
||||
|
||||
const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => {
|
||||
const { nodeId, width, children, selected } = props;
|
||||
const { nodeId, width, children, selected, borderRadius = 'base', withChrome = true } = props;
|
||||
const mouseOverNode = useMouseOverNode(nodeId);
|
||||
const zoomToNode = useZoomToNode(nodeId);
|
||||
|
||||
@@ -62,14 +64,15 @@ const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => {
|
||||
onMouseOut={mouseOverNode.handleMouseOut}
|
||||
className={DRAG_HANDLE_CLASSNAME}
|
||||
sx={containerSx}
|
||||
borderRadius={borderRadius}
|
||||
width={width || NODE_WIDTH}
|
||||
opacity={opacity}
|
||||
data-is-selected={selected}
|
||||
>
|
||||
<Box sx={shadowsSx} />
|
||||
<Box sx={inProgressSx} data-is-in-progress={isInProgress} />
|
||||
{withChrome && <Box sx={shadowsSx} />}
|
||||
{withChrome && <Box sx={inProgressSx} data-is-in-progress={isInProgress} />}
|
||||
{children}
|
||||
<Box className="node-selection-overlay" />
|
||||
{withChrome && <Box className="node-selection-overlay" />}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -6,7 +6,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
export const containerSx: SystemStyleObject = {
|
||||
h: 'full',
|
||||
position: 'relative',
|
||||
borderRadius: 'base',
|
||||
borderRadius: 'inherit',
|
||||
transitionProperty: 'none',
|
||||
cursor: 'grab',
|
||||
'--border-color': 'var(--invoke-colors-base-500)',
|
||||
@@ -30,7 +30,7 @@ export const containerSx: SystemStyleObject = {
|
||||
insetInlineEnd: 0,
|
||||
bottom: 0,
|
||||
insetInlineStart: 0,
|
||||
borderRadius: 'base',
|
||||
borderRadius: 'inherit',
|
||||
transitionProperty: 'none',
|
||||
pointerEvents: 'none',
|
||||
shadow: '0 0 0 1px var(--border-color)',
|
||||
@@ -64,7 +64,7 @@ export const shadowsSx: SystemStyleObject = {
|
||||
insetInlineEnd: 0,
|
||||
bottom: 0,
|
||||
insetInlineStart: 0,
|
||||
borderRadius: 'base',
|
||||
borderRadius: 'inherit',
|
||||
pointerEvents: 'none',
|
||||
zIndex: -1,
|
||||
shadow: 'var(--invoke-shadows-xl), var(--invoke-shadows-base), var(--invoke-shadows-base)',
|
||||
@@ -76,7 +76,7 @@ export const inProgressSx: SystemStyleObject = {
|
||||
insetInlineEnd: 0,
|
||||
bottom: 0,
|
||||
insetInlineStart: 0,
|
||||
borderRadius: 'md',
|
||||
borderRadius: 'inherit',
|
||||
pointerEvents: 'none',
|
||||
transitionProperty: 'none',
|
||||
opacity: 0.7,
|
||||
|
||||
@@ -3,6 +3,7 @@ import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { $flow } from 'features/nodes/store/reactFlowInstance';
|
||||
import { NODE_WIDTH } from 'features/nodes/types/constants';
|
||||
import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode';
|
||||
import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode';
|
||||
import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode';
|
||||
import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode';
|
||||
@@ -14,7 +15,7 @@ export const useBuildNode = () => {
|
||||
|
||||
return useCallback(
|
||||
// string here is "any invocation type"
|
||||
(type: string | 'current_image' | 'notes'): AnyNode => {
|
||||
(type: string | 'connector' | 'current_image' | 'notes'): AnyNode => {
|
||||
const flow = $flow.get();
|
||||
assert(flow !== null);
|
||||
|
||||
@@ -42,6 +43,10 @@ export const useBuildNode = () => {
|
||||
return buildNotesNode(position);
|
||||
}
|
||||
|
||||
if (type === 'connector') {
|
||||
return buildConnectorNode(position);
|
||||
}
|
||||
|
||||
// TODO: Keep track of invocation types so we do not need to cast this
|
||||
// We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates.
|
||||
const template = templates[type] as InvocationTemplate;
|
||||
|
||||
@@ -12,9 +12,15 @@ import {
|
||||
edgesChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import {
|
||||
CONNECTOR_INPUT_HANDLE,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
resolveConnectorSourceFieldType,
|
||||
} from 'features/nodes/store/util/connectorTopology';
|
||||
import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import type { AnyEdge } from 'features/nodes/types/invocation';
|
||||
import { isConnectorNode } from 'features/nodes/types/invocation';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
@@ -33,6 +39,47 @@ export const useConnection = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (isConnectorNode(node)) {
|
||||
if (handleType === 'source' && handleId !== CONNECTOR_OUTPUT_HANDLE) {
|
||||
return;
|
||||
}
|
||||
if (handleType === 'target' && handleId !== CONNECTOR_INPUT_HANDLE) {
|
||||
return;
|
||||
}
|
||||
|
||||
const resolvedSourceType =
|
||||
handleType === 'source'
|
||||
? resolveConnectorSourceFieldType(nodeId, nodes, selectNodesSlice(store.getState()).edges, templates)
|
||||
: null;
|
||||
$pendingConnection.set({
|
||||
nodeId,
|
||||
handleId,
|
||||
handleType,
|
||||
fieldTemplate:
|
||||
handleType === 'source'
|
||||
? {
|
||||
name: CONNECTOR_OUTPUT_HANDLE,
|
||||
title: 'Connector Output',
|
||||
description: '',
|
||||
fieldKind: 'output',
|
||||
ui_hidden: false,
|
||||
type: resolvedSourceType ?? { name: 'AnyField', cardinality: 'SINGLE', batch: false },
|
||||
}
|
||||
: {
|
||||
name: CONNECTOR_INPUT_HANDLE,
|
||||
title: 'Connector Input',
|
||||
description: '',
|
||||
fieldKind: 'input',
|
||||
input: 'connection',
|
||||
required: false,
|
||||
default: undefined,
|
||||
ui_hidden: false,
|
||||
type: { name: 'AnyField', cardinality: 'SINGLE', batch: false },
|
||||
},
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const template = templates[node.data.type];
|
||||
if (!template) {
|
||||
return;
|
||||
|
||||
@@ -0,0 +1,160 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { connectorInserted, nodesChanged, nodesSliceConfig } from './nodesSlice';
|
||||
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from './util/connectorTopology';
|
||||
import { add, buildEdge, buildNode, sub } from './util/testUtils';
|
||||
|
||||
const buildFixedConnectorNode = (id: string) => {
|
||||
const connectorNode = buildConnectorNode({ x: 0, y: 0 });
|
||||
return {
|
||||
...connectorNode,
|
||||
id,
|
||||
data: {
|
||||
...connectorNode.data,
|
||||
id,
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
describe('nodesSlice connector actions', () => {
|
||||
it('splits a direct edge into source -> connector -> target edges when inserting a connector', () => {
|
||||
const source = buildNode(add);
|
||||
const target = buildNode(sub);
|
||||
const connector = buildFixedConnectorNode('connector-1');
|
||||
const directEdge = buildEdge(source.id, 'value', target.id, 'a');
|
||||
|
||||
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
|
||||
initialState.nodes = [source, target];
|
||||
initialState.edges = [directEdge];
|
||||
|
||||
const nextState = nodesSliceConfig.slice.reducer(
|
||||
initialState,
|
||||
connectorInserted({
|
||||
edgeId: directEdge.id,
|
||||
connector,
|
||||
})
|
||||
);
|
||||
|
||||
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id, connector.id]);
|
||||
expect(nextState.edges).toEqual([
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
]);
|
||||
});
|
||||
|
||||
it('splices connector outputs back to the resolved upstream source when removed', () => {
|
||||
const source = buildNode(add);
|
||||
const target = buildNode(sub);
|
||||
const connector = buildFixedConnectorNode('connector-1');
|
||||
|
||||
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
|
||||
initialState.nodes = [source, connector, target];
|
||||
initialState.edges = [
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
];
|
||||
|
||||
const nextState = nodesSliceConfig.slice.reducer(
|
||||
initialState,
|
||||
nodesChanged([{ type: 'remove', id: connector.id }])
|
||||
);
|
||||
|
||||
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]);
|
||||
expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]);
|
||||
});
|
||||
|
||||
it('splices one connector source back to multiple downstream targets when removed', () => {
|
||||
const source = buildNode(add);
|
||||
const targetA = buildNode(sub);
|
||||
const targetB = buildNode(sub);
|
||||
const connector = buildFixedConnectorNode('connector-1');
|
||||
|
||||
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
|
||||
initialState.nodes = [source, connector, targetA, targetB];
|
||||
initialState.edges = [
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'b'),
|
||||
];
|
||||
|
||||
const nextState = nodesSliceConfig.slice.reducer(
|
||||
initialState,
|
||||
nodesChanged([{ type: 'remove', id: connector.id }])
|
||||
);
|
||||
|
||||
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, targetA.id, targetB.id]);
|
||||
expect(nextState.edges).toEqual([
|
||||
buildEdge(source.id, 'value', targetA.id, 'a'),
|
||||
buildEdge(source.id, 'value', targetB.id, 'b'),
|
||||
]);
|
||||
});
|
||||
|
||||
it('does not create any edges when removing a connector with no downstream targets', () => {
|
||||
const source = buildNode(add);
|
||||
const connector = buildFixedConnectorNode('connector-1');
|
||||
|
||||
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
|
||||
initialState.nodes = [source, connector];
|
||||
initialState.edges = [buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
|
||||
|
||||
const nextState = nodesSliceConfig.slice.reducer(
|
||||
initialState,
|
||||
nodesChanged([{ type: 'remove', id: connector.id }])
|
||||
);
|
||||
|
||||
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id]);
|
||||
expect(nextState.edges).toEqual([]);
|
||||
});
|
||||
|
||||
it('removes a connector while preserving downstream connector edges in a chained splice case', () => {
|
||||
const source = buildNode(add);
|
||||
const connectorA = buildFixedConnectorNode('connector-a');
|
||||
const connectorB = buildFixedConnectorNode('connector-b');
|
||||
const target = buildNode(sub);
|
||||
|
||||
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
|
||||
initialState.nodes = [source, connectorA, connectorB, target];
|
||||
initialState.edges = [
|
||||
buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
];
|
||||
|
||||
const nextState = nodesSliceConfig.slice.reducer(
|
||||
initialState,
|
||||
nodesChanged([{ type: 'remove', id: connectorA.id }])
|
||||
);
|
||||
|
||||
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, connectorB.id, target.id]);
|
||||
expect(nextState.edges).toHaveLength(2);
|
||||
expect(nextState.edges).toEqual(
|
||||
expect.arrayContaining([
|
||||
buildEdge(source.id, 'value', connectorB.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
])
|
||||
);
|
||||
});
|
||||
|
||||
it('splices connector edges when the connector is removed through generic node removal', () => {
|
||||
const source = buildNode(add);
|
||||
const target = buildNode(sub);
|
||||
const connector = buildFixedConnectorNode('connector-1');
|
||||
|
||||
const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' }));
|
||||
initialState.nodes = [source, connector, target];
|
||||
initialState.edges = [
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
];
|
||||
|
||||
const nextState = nodesSliceConfig.slice.reducer(
|
||||
initialState,
|
||||
nodesChanged([{ type: 'remove', id: connector.id }])
|
||||
);
|
||||
|
||||
expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]);
|
||||
expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]);
|
||||
});
|
||||
});
|
||||
@@ -20,6 +20,13 @@ import {
|
||||
reparentElement,
|
||||
} from 'features/nodes/components/sidePanel/builder/form-manipulation';
|
||||
import { type NodesState, zNodesState } from 'features/nodes/store/types';
|
||||
import {
|
||||
CONNECTOR_INPUT_HANDLE,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
getConnectorOutputEdges,
|
||||
resolveConnectorSource,
|
||||
} from 'features/nodes/store/util/connectorTopology';
|
||||
import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type {
|
||||
BoardFieldValue,
|
||||
@@ -65,8 +72,8 @@ import {
|
||||
zStringGeneratorFieldValue,
|
||||
zStylePresetFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import type { AnyEdge, AnyNode, ConnectorNode } from 'features/nodes/types/invocation';
|
||||
import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import type {
|
||||
BuilderForm,
|
||||
ContainerElement,
|
||||
@@ -103,7 +110,7 @@ export const getInitialWorkflow = (): Omit<NodesState, 'mode' | 'formFieldInitia
|
||||
tags: '',
|
||||
notes: '',
|
||||
exposedFields: [],
|
||||
meta: { version: '3.0.0', category: 'user' },
|
||||
meta: { version: '4.0.0', category: 'user' },
|
||||
form: getDefaultForm(),
|
||||
nodes: [],
|
||||
edges: [],
|
||||
@@ -175,6 +182,33 @@ const slice = createSlice({
|
||||
initialState: getInitialState(),
|
||||
reducers: {
|
||||
nodesChanged: (state, action: PayloadAction<NodeChange<AnyNode>[]>) => {
|
||||
const removedConnectorSpliceEdges: AnyEdge[] = action.payload.flatMap((change) => {
|
||||
if (change.type !== 'remove') {
|
||||
return [];
|
||||
}
|
||||
|
||||
const node = state.nodes.find((candidate) => candidate.id === change.id);
|
||||
if (!isConnectorNode(node)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const resolvedSource = resolveConnectorSource(node.id, state.nodes, state.edges);
|
||||
if (!resolvedSource) {
|
||||
return [];
|
||||
}
|
||||
|
||||
return getConnectorOutputEdges(node.id, state.edges)
|
||||
.filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default')
|
||||
.map((edge) =>
|
||||
connectionToEdge({
|
||||
source: resolvedSource.nodeId,
|
||||
sourceHandle: resolvedSource.fieldName,
|
||||
target: edge.target,
|
||||
targetHandle: edge.targetHandle,
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// TODO(psyche): The below TS issue was recently fixed upstream. Need to upgrade @xyflow/react and then we
|
||||
// should be able to remove this cast.
|
||||
//
|
||||
@@ -206,6 +240,12 @@ const slice = createSlice({
|
||||
if (edgeChanges.length > 0) {
|
||||
state.edges = applyEdgeChanges<AnyEdge>(edgeChanges, state.edges);
|
||||
}
|
||||
if (removedConnectorSpliceEdges.length > 0) {
|
||||
state.edges = applyEdgeChanges<AnyEdge>(
|
||||
removedConnectorSpliceEdges.map((edge) => ({ type: 'add', item: edge })),
|
||||
state.edges
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const wereNodesRemoved = action.payload.some((change) => change.type === 'remove' || change.type === 'replace');
|
||||
@@ -396,11 +436,40 @@ const slice = createSlice({
|
||||
}
|
||||
}
|
||||
},
|
||||
connectorInserted: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
edgeId: string;
|
||||
connector: ConnectorNode;
|
||||
}>
|
||||
) => {
|
||||
const { edgeId, connector } = action.payload;
|
||||
const edge = state.edges.find((candidate) => candidate.id === edgeId);
|
||||
if (!edge || edge.type !== 'default') {
|
||||
return;
|
||||
}
|
||||
state.nodes.push({ ...SHARED_NODE_PROPERTIES, ...connector } as (typeof state.nodes)[number]);
|
||||
state.edges = state.edges.filter((candidate) => candidate.id !== edgeId);
|
||||
state.edges.push(
|
||||
connectionToEdge({
|
||||
source: edge.source,
|
||||
sourceHandle: edge.sourceHandle ?? null,
|
||||
target: connector.id,
|
||||
targetHandle: CONNECTOR_INPUT_HANDLE,
|
||||
}),
|
||||
connectionToEdge({
|
||||
source: connector.id,
|
||||
sourceHandle: CONNECTOR_OUTPUT_HANDLE,
|
||||
target: edge.target,
|
||||
targetHandle: edge.targetHandle ?? null,
|
||||
})
|
||||
);
|
||||
},
|
||||
nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => {
|
||||
const { nodeId, label } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
if (isInvocationNode(node) || isNotesNode(node)) {
|
||||
if (isInvocationNode(node) || isNotesNode(node) || isConnectorNode(node)) {
|
||||
node.data.label = label;
|
||||
}
|
||||
},
|
||||
@@ -614,6 +683,7 @@ export const {
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
connectorInserted,
|
||||
nodeLabelChanged,
|
||||
nodeNotesChanged,
|
||||
nodesChanged,
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
import type { AnyNode, ConnectorNode } from 'features/nodes/types/invocation';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
CONNECTOR_INPUT_HANDLE,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
getConnectorDeletionSpliceConnections,
|
||||
getConnectorInputEdge,
|
||||
getConnectorOutputEdges,
|
||||
resolveConnectorSource,
|
||||
resolveConnectorSourceFieldType,
|
||||
} from './connectorTopology';
|
||||
import { add, buildEdge, buildNode, img_resize, sub, templates } from './testUtils';
|
||||
|
||||
const buildConnectorNode = (id: string): ConnectorNode => ({
|
||||
id,
|
||||
type: 'connector',
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
id,
|
||||
type: 'connector',
|
||||
label: 'Connector',
|
||||
isOpen: true,
|
||||
},
|
||||
});
|
||||
|
||||
describe('connectorTopology', () => {
|
||||
it('resolves the effective upstream source through one connector', () => {
|
||||
const source = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const target = buildNode(sub);
|
||||
const nodes: AnyNode[] = [source, connector, target];
|
||||
const edges = [
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
];
|
||||
|
||||
expect(resolveConnectorSource(connector.id, nodes, edges)).toEqual({
|
||||
nodeId: source.id,
|
||||
fieldName: 'value',
|
||||
});
|
||||
expect(resolveConnectorSourceFieldType(connector.id, nodes, edges, templates)).toEqual(add.outputs.value?.type);
|
||||
});
|
||||
|
||||
it('resolves the effective upstream source through chained connectors', () => {
|
||||
const source = buildNode(add);
|
||||
const connectorA = buildConnectorNode('connector-a');
|
||||
const connectorB = buildConnectorNode('connector-b');
|
||||
const nodes: AnyNode[] = [source, connectorA, connectorB];
|
||||
const edges = [
|
||||
buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
|
||||
];
|
||||
|
||||
expect(resolveConnectorSource(connectorB.id, nodes, edges)).toEqual({
|
||||
nodeId: source.id,
|
||||
fieldName: 'value',
|
||||
});
|
||||
});
|
||||
|
||||
it('returns no source or type for an unresolved connector chain', () => {
|
||||
const connectorA = buildConnectorNode('connector-a');
|
||||
const connectorB = buildConnectorNode('connector-b');
|
||||
const nodes: AnyNode[] = [connectorA, connectorB];
|
||||
const edges = [buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE)];
|
||||
|
||||
expect(resolveConnectorSource(connectorB.id, nodes, edges)).toBe(null);
|
||||
expect(resolveConnectorSourceFieldType(connectorB.id, nodes, edges, templates)).toBe(null);
|
||||
});
|
||||
|
||||
it('enumerates multiple outgoing edges for a connector', () => {
|
||||
const source = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const targetA = buildNode(sub);
|
||||
const targetB = buildNode(img_resize);
|
||||
const incoming = buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE);
|
||||
const outgoingA = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a');
|
||||
const outgoingB = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width');
|
||||
const edges = [incoming, outgoingA, outgoingB];
|
||||
|
||||
expect(getConnectorInputEdge(connector.id, edges)).toEqual(incoming);
|
||||
expect(getConnectorOutputEdges(connector.id, edges)).toEqual([outgoingA, outgoingB]);
|
||||
});
|
||||
|
||||
it('rejects connector deletion splice-through when any downstream target would be invalid', () => {
|
||||
const source = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const target = buildNode(img_resize);
|
||||
const nodes: AnyNode[] = [source, connector, target];
|
||||
const edges = [
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'image'),
|
||||
];
|
||||
|
||||
expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null);
|
||||
});
|
||||
|
||||
it('builds connector deletion splice-through edges when every downstream target remains valid', () => {
|
||||
const source = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const target = buildNode(sub);
|
||||
const nodes: AnyNode[] = [source, connector, target];
|
||||
const edges = [
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
];
|
||||
|
||||
expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toEqual([
|
||||
{
|
||||
source: source.id,
|
||||
sourceHandle: 'value',
|
||||
target: target.id,
|
||||
targetHandle: 'a',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('returns no splice-through edges when a connector has downstream targets but no upstream source', () => {
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const target = buildNode(sub);
|
||||
const nodes: AnyNode[] = [connector, target];
|
||||
const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')];
|
||||
|
||||
expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null);
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,228 @@
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
|
||||
export const CONNECTOR_INPUT_HANDLE = 'in';
|
||||
export const CONNECTOR_OUTPUT_HANDLE = 'out';
|
||||
|
||||
type ResolvedConnectorSource = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
};
|
||||
|
||||
type SpliceConnection = {
|
||||
source: string;
|
||||
sourceHandle: string;
|
||||
target: string;
|
||||
targetHandle: string;
|
||||
};
|
||||
|
||||
type SpliceConnectionValidator = (
|
||||
connection: SpliceConnection,
|
||||
nodes: AnyNode[],
|
||||
edges: AnyEdge[],
|
||||
templates: Templates,
|
||||
ignoreEdge: AnyEdge | null,
|
||||
strict?: boolean
|
||||
) => string | null;
|
||||
|
||||
export const getConnectorInputEdge = (connectorId: string, edges: AnyEdge[]): AnyEdge | undefined =>
|
||||
edges.find(
|
||||
(edge) =>
|
||||
edge.type === 'default' &&
|
||||
edge.target === connectorId &&
|
||||
edge.targetHandle === CONNECTOR_INPUT_HANDLE &&
|
||||
typeof edge.sourceHandle === 'string'
|
||||
);
|
||||
|
||||
export const getConnectorOutputEdges = (connectorId: string, edges: AnyEdge[]): AnyEdge[] =>
|
||||
edges.filter(
|
||||
(edge) =>
|
||||
edge.type === 'default' &&
|
||||
edge.source === connectorId &&
|
||||
edge.sourceHandle === CONNECTOR_OUTPUT_HANDLE &&
|
||||
typeof edge.targetHandle === 'string'
|
||||
);
|
||||
|
||||
export const resolveConnectorSource = (
|
||||
connectorId: string,
|
||||
nodes: AnyNode[],
|
||||
edges: AnyEdge[]
|
||||
): ResolvedConnectorSource | null => {
|
||||
const visited = new Set<string>();
|
||||
|
||||
const resolve = (nodeId: string): ResolvedConnectorSource | null => {
|
||||
if (visited.has(nodeId)) {
|
||||
return null;
|
||||
}
|
||||
visited.add(nodeId);
|
||||
|
||||
const incomingEdge = getConnectorInputEdge(nodeId, edges);
|
||||
if (!incomingEdge || incomingEdge.type !== 'default') {
|
||||
return null;
|
||||
}
|
||||
if (typeof incomingEdge.sourceHandle !== 'string') {
|
||||
return null;
|
||||
}
|
||||
|
||||
const sourceNode = nodes.find((node) => node.id === incomingEdge.source);
|
||||
if (!sourceNode) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isInvocationNode(sourceNode)) {
|
||||
return { nodeId: sourceNode.id, fieldName: incomingEdge.sourceHandle };
|
||||
}
|
||||
|
||||
if (isConnectorNode(sourceNode)) {
|
||||
return resolve(sourceNode.id);
|
||||
}
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
return resolve(connectorId);
|
||||
};
|
||||
|
||||
export const resolveConnectorSourceFieldType = (
|
||||
connectorId: string,
|
||||
nodes: AnyNode[],
|
||||
edges: AnyEdge[],
|
||||
templates: Templates
|
||||
): FieldType | null => {
|
||||
const resolvedSource = resolveConnectorSource(connectorId, nodes, edges);
|
||||
if (!resolvedSource) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const sourceNode = nodes.find((node) => node.id === resolvedSource.nodeId);
|
||||
if (!sourceNode || !isInvocationNode(sourceNode)) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
return sourceTemplate?.outputs[resolvedSource.fieldName]?.type ?? null;
|
||||
};
|
||||
|
||||
export const getConnectorDeletionSpliceConnections = (
|
||||
connectorId: string,
|
||||
nodes: AnyNode[],
|
||||
edges: AnyEdge[],
|
||||
templates: Templates,
|
||||
validateConnection?: SpliceConnectionValidator
|
||||
): SpliceConnection[] | null => {
|
||||
const resolvedSource = resolveConnectorSource(connectorId, nodes, edges);
|
||||
if (!resolvedSource) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const outputEdges = getConnectorOutputEdges(connectorId, edges);
|
||||
const spliceConnections = outputEdges
|
||||
.filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default')
|
||||
.map((edge) => ({
|
||||
source: resolvedSource.nodeId,
|
||||
sourceHandle: resolvedSource.fieldName,
|
||||
target: edge.target,
|
||||
targetHandle: edge.targetHandle,
|
||||
}));
|
||||
|
||||
const deduped = new Set<string>();
|
||||
for (const connection of spliceConnections) {
|
||||
const key = `${connection.source}:${connection.sourceHandle}->${connection.target}:${connection.targetHandle}`;
|
||||
if (deduped.has(key)) {
|
||||
return null;
|
||||
}
|
||||
deduped.add(key);
|
||||
}
|
||||
|
||||
if (!validateConnection) {
|
||||
const sourceType = resolveConnectorSourceFieldType(connectorId, nodes, edges, templates);
|
||||
if (!sourceType) {
|
||||
return null;
|
||||
}
|
||||
const inputEdgeId = getConnectorInputEdge(connectorId, edges)?.id;
|
||||
const outputEdgeIds = new Set(outputEdges.map((edge) => edge.id));
|
||||
|
||||
for (const connection of spliceConnections) {
|
||||
const targetNode = nodes.find((node) => node.id === connection.target);
|
||||
if (!targetNode || !isInvocationNode(targetNode)) {
|
||||
return null;
|
||||
}
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
const targetFieldTemplate = targetTemplate?.inputs[connection.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const matchesExistingDirectEdge = edges.some(
|
||||
(edge) =>
|
||||
edge.type === 'default' &&
|
||||
edge.source === connection.source &&
|
||||
edge.sourceHandle === connection.sourceHandle &&
|
||||
edge.target === connection.target &&
|
||||
edge.targetHandle === connection.targetHandle
|
||||
);
|
||||
if (matchesExistingDirectEdge) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const targetConflictCount = spliceConnections.filter(
|
||||
(candidate) => candidate.target === connection.target && candidate.targetHandle === connection.targetHandle
|
||||
).length;
|
||||
const existingTargetConflict = edges.some(
|
||||
(edge) =>
|
||||
edge.type === 'default' &&
|
||||
edge.id !== inputEdgeId &&
|
||||
!outputEdgeIds.has(edge.id) &&
|
||||
edge.target === connection.target &&
|
||||
edge.targetHandle === connection.targetHandle
|
||||
);
|
||||
if (
|
||||
targetFieldTemplate.type.name !== 'CollectionItemField' &&
|
||||
(targetConflictCount > 1 || existingTargetConflict)
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (
|
||||
sourceType.name !== targetFieldTemplate.type.name &&
|
||||
targetFieldTemplate.type.name !== 'CollectionItemField'
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
return spliceConnections;
|
||||
}
|
||||
|
||||
const ignoredEdgeIds = new Set([
|
||||
getConnectorInputEdge(connectorId, edges)?.id,
|
||||
...outputEdges.map((edge) => edge.id),
|
||||
]);
|
||||
const existingEdges = edges.filter((edge) => !ignoredEdgeIds.has(edge.id));
|
||||
const stagedConnections: SpliceConnection[] = [];
|
||||
|
||||
for (const connection of spliceConnections) {
|
||||
const stagedEdges = [
|
||||
...existingEdges,
|
||||
...stagedConnections.map(
|
||||
({ source, sourceHandle, target, targetHandle }) =>
|
||||
({
|
||||
id: `splice-${source}-${sourceHandle}-${target}-${targetHandle}`,
|
||||
type: 'default',
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
targetHandle,
|
||||
}) satisfies AnyEdge
|
||||
),
|
||||
];
|
||||
if (validateConnection(connection, nodes, stagedEdges, templates, null, true) !== null) {
|
||||
return null;
|
||||
}
|
||||
stagedConnections.push(connection);
|
||||
}
|
||||
|
||||
return spliceConnections;
|
||||
};
|
||||
@@ -1,13 +1,26 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { unset } from 'es-toolkit/compat';
|
||||
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
|
||||
import {
|
||||
getFirstValidConnection,
|
||||
getSourceCandidateFields,
|
||||
getTargetCandidateFields,
|
||||
} from 'features/nodes/store/util/getFirstValidConnection';
|
||||
import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils';
|
||||
import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
const buildConnectorNode = (id: string) => ({
|
||||
id,
|
||||
type: 'connector' as const,
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
id,
|
||||
type: 'connector' as const,
|
||||
label: 'Connector',
|
||||
isOpen: true,
|
||||
},
|
||||
});
|
||||
|
||||
describe('getFirstValidConnection', () => {
|
||||
it('should return null if the pending and candidate nodes are the same node', () => {
|
||||
const n = buildNode(add);
|
||||
@@ -120,6 +133,33 @@ describe('getFirstValidConnection', () => {
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
});
|
||||
|
||||
it('should resolve connector target candidates when connecting an invocation output to a connector', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
expect(getFirstValidConnection(n1.id, 'value', connector.id, null, [n1, connector], [], templates, null)).toEqual({
|
||||
source: n1.id,
|
||||
sourceHandle: 'value',
|
||||
target: connector.id,
|
||||
targetHandle: CONNECTOR_INPUT_HANDLE,
|
||||
});
|
||||
});
|
||||
|
||||
it('should resolve connector source candidates when connecting a connector to a typed invocation input', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n2 = buildNode(img_resize);
|
||||
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
|
||||
|
||||
expect(
|
||||
getFirstValidConnection(connector.id, null, n2.id, 'width', [n1, connector, n2], edges, templates, null)
|
||||
).toEqual({
|
||||
source: connector.id,
|
||||
sourceHandle: CONNECTOR_OUTPUT_HANDLE,
|
||||
target: n2.id,
|
||||
targetHandle: 'width',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('getTargetCandidateFields', () => {
|
||||
@@ -160,6 +200,62 @@ describe('getTargetCandidateFields', () => {
|
||||
const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate);
|
||||
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
|
||||
});
|
||||
it('should return the connector input handle when the target is a connector', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const r = getTargetCandidateFields(n1.id, 'value', connector.id, [n1, connector], [], templates, null);
|
||||
expect(r.map((field) => field.name)).toEqual([CONNECTOR_INPUT_HANDLE]);
|
||||
});
|
||||
it('should advertise typed target candidates for an unresolved connector output when no downstream constraint exists', () => {
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n2 = buildNode(sub);
|
||||
const r = getTargetCandidateFields(
|
||||
connector.id,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
n2.id,
|
||||
[connector, n2],
|
||||
[],
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r.map((field) => field.name)).toEqual(['a', 'b']);
|
||||
});
|
||||
it('should only advertise compatible typed target candidates for an unresolved connector output with downstream constraints', () => {
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n1 = buildNode(sub);
|
||||
const n2 = buildNode(img_resize);
|
||||
const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')];
|
||||
const r = getTargetCandidateFields(
|
||||
connector.id,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
n2.id,
|
||||
[connector, n1, n2],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r.map((field) => field.name)).toEqual(['width', 'height']);
|
||||
});
|
||||
it('should resolve chained connector sources like the direct upstream source', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connectorA = buildConnectorNode('connector-a');
|
||||
const connectorB = buildConnectorNode('connector-b');
|
||||
const n2 = buildNode(img_resize);
|
||||
const edges = [
|
||||
buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
|
||||
];
|
||||
const r = getTargetCandidateFields(
|
||||
connectorB.id,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
n2.id,
|
||||
[n1, connectorA, connectorB, n2],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSourceCandidateFields', () => {
|
||||
@@ -200,4 +296,18 @@ describe('getSourceCandidateFields', () => {
|
||||
const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate);
|
||||
expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]);
|
||||
});
|
||||
it('should return the connector output handle when the source is a connector with a typed upstream source', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n2 = buildNode(img_resize);
|
||||
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
|
||||
const r = getSourceCandidateFields(n2.id, 'width', connector.id, [n1, connector, n2], edges, templates, null);
|
||||
expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]);
|
||||
});
|
||||
it('should return a target-constrained connector source candidate when the connector chain is unresolved', () => {
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n2 = buildNode(img_resize);
|
||||
const r = getSourceCandidateFields(n2.id, 'width', connector.id, [connector, n2], [], templates, null);
|
||||
expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import type { Connection } from '@xyflow/react';
|
||||
import { map } from 'es-toolkit/compat';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import {
|
||||
CONNECTOR_INPUT_HANDLE,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
resolveConnectorSourceFieldType,
|
||||
} from 'features/nodes/store/util/connectorTopology';
|
||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import { isConnectorNode } from 'features/nodes/types/invocation';
|
||||
|
||||
/**
|
||||
*
|
||||
@@ -91,16 +97,43 @@ export const getTargetCandidateFields = (
|
||||
return [];
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (isConnectorNode(targetNode)) {
|
||||
const candidate = {
|
||||
name: CONNECTOR_INPUT_HANDLE,
|
||||
title: 'Connector Input',
|
||||
description: '',
|
||||
fieldKind: 'input',
|
||||
input: 'connection',
|
||||
required: false,
|
||||
default: undefined,
|
||||
ui_hidden: false,
|
||||
type: {
|
||||
name: 'AnyField',
|
||||
cardinality: 'SINGLE',
|
||||
batch: false,
|
||||
},
|
||||
} satisfies FieldInputTemplate;
|
||||
|
||||
const c = { source, sourceHandle, target, targetHandle: candidate.name };
|
||||
return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : [];
|
||||
}
|
||||
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!sourceTemplate || !targetTemplate) {
|
||||
if (!targetTemplate) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const sourceField = sourceTemplate.outputs[sourceHandle];
|
||||
if (!isConnectorNode(sourceNode)) {
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (!sourceField) {
|
||||
return [];
|
||||
const sourceField = sourceTemplate.outputs[sourceHandle];
|
||||
|
||||
if (!sourceField) {
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
const targetCandidateFields = map(targetTemplate.inputs).filter((field) => {
|
||||
@@ -127,15 +160,45 @@ export const getSourceCandidateFields = (
|
||||
return [];
|
||||
}
|
||||
|
||||
if (isConnectorNode(sourceNode)) {
|
||||
const sourceFieldType = resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates);
|
||||
const targetTemplate = !isConnectorNode(targetNode) ? templates[targetNode.data.type] : null;
|
||||
const targetFieldType = targetTemplate?.inputs[targetHandle]?.type;
|
||||
const candidateType = sourceFieldType ?? targetFieldType;
|
||||
if (!candidateType) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const candidate = {
|
||||
name: CONNECTOR_OUTPUT_HANDLE,
|
||||
title: 'Connector Output',
|
||||
description: '',
|
||||
fieldKind: 'output',
|
||||
ui_hidden: false,
|
||||
type: candidateType,
|
||||
} satisfies FieldOutputTemplate;
|
||||
|
||||
const c = { source, sourceHandle: candidate.name, target, targetHandle };
|
||||
return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : [];
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!sourceTemplate || !targetTemplate) {
|
||||
if (!sourceTemplate) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const targetField = targetTemplate.inputs[targetHandle];
|
||||
if (!isConnectorNode(targetNode)) {
|
||||
const targetTemplate = templates[targetNode.data.type];
|
||||
if (!targetTemplate) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (!targetField) {
|
||||
const targetField = targetTemplate.inputs[targetHandle];
|
||||
|
||||
if (!targetField) {
|
||||
return [];
|
||||
}
|
||||
} else if (targetHandle !== CONNECTOR_INPUT_HANDLE) {
|
||||
return [];
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { connectionToEdge } from './reactFlowUtil';
|
||||
|
||||
describe('connectionToEdge', () => {
|
||||
it('creates a default edge with the expected id and endpoints', () => {
|
||||
expect(
|
||||
connectionToEdge({
|
||||
source: 'source-node',
|
||||
sourceHandle: 'value',
|
||||
target: 'target-node',
|
||||
targetHandle: 'a',
|
||||
})
|
||||
).toEqual({
|
||||
type: 'default',
|
||||
source: 'source-node',
|
||||
sourceHandle: 'value',
|
||||
target: 'target-node',
|
||||
targetHandle: 'a',
|
||||
id: 'reactflow__edge-source-nodevalue-target-nodea',
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -24,6 +24,7 @@ export const connectionToEdge = (connection: Connection): AnyEdge => {
|
||||
const { source, sourceHandle, target, targetHandle } = connection;
|
||||
assert(source && sourceHandle && target && targetHandle, 'Invalid connection');
|
||||
return {
|
||||
type: 'default',
|
||||
source,
|
||||
sourceHandle,
|
||||
target,
|
||||
|
||||
@@ -3,6 +3,11 @@ import { set } from 'es-toolkit/compat';
|
||||
import type { InvocationTemplate } from 'features/nodes/types/invocation';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import {
|
||||
CONNECTOR_INPUT_HANDLE,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
getConnectorDeletionSpliceConnections,
|
||||
} from './connectorTopology';
|
||||
import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils';
|
||||
import { validateConnection } from './validateConnection';
|
||||
|
||||
@@ -139,6 +144,18 @@ const integerCollectionOutputTemplate: InvocationTemplate = {
|
||||
classification: 'stable',
|
||||
};
|
||||
|
||||
const buildConnectorNode = (id: string) => ({
|
||||
id,
|
||||
type: 'connector' as const,
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
id,
|
||||
type: 'connector' as const,
|
||||
label: 'Connector',
|
||||
isOpen: true,
|
||||
},
|
||||
});
|
||||
|
||||
describe(validateConnection.name, () => {
|
||||
it('should reject invalid connection to self', () => {
|
||||
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||
@@ -458,6 +475,214 @@ describe(validateConnection.name, () => {
|
||||
expect(r).toEqual('nodes.connectionWouldCreateCycle');
|
||||
});
|
||||
|
||||
describe('connectors', () => {
|
||||
it('should accept invocation output to connector input', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const r = validateConnection(
|
||||
{ source: n1.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE },
|
||||
[n1, connector],
|
||||
[],
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
|
||||
it('should reject a second input into a connector', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
|
||||
const r = validateConnection(
|
||||
{ source: n2.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE },
|
||||
[n1, n2, connector],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection');
|
||||
});
|
||||
|
||||
it('should accept connector output to invocation input when the upstream type matches', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n2 = buildNode(sub);
|
||||
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
|
||||
const r = validateConnection(
|
||||
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' },
|
||||
[n1, connector, n2],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
|
||||
it('should reject connector output to invocation input when the upstream type mismatches', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n2 = buildNode(img_resize);
|
||||
const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)];
|
||||
const r = validateConnection(
|
||||
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' },
|
||||
[n1, connector, n2],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual('nodes.fieldTypesMustMatch');
|
||||
});
|
||||
|
||||
it('should accept unresolved connector output to a typed invocation input as the first downstream constraint', () => {
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n2 = buildNode(sub);
|
||||
const r = validateConnection(
|
||||
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' },
|
||||
[connector, n2],
|
||||
[],
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
|
||||
it('should reject unresolved connector output when it conflicts with an existing downstream typed constraint', () => {
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const n1 = buildNode(sub);
|
||||
const n2 = buildNode(img_resize);
|
||||
const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')];
|
||||
const r = validateConnection(
|
||||
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' },
|
||||
[connector, n1, n2],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual('nodes.fieldTypesMustMatch');
|
||||
});
|
||||
|
||||
it('should reject connecting an incompatible upstream source into a connector with downstream typed constraints', () => {
|
||||
const source = buildNode(main_model_loader);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const target = buildNode(sub);
|
||||
const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')];
|
||||
const r = validateConnection(
|
||||
{ source: source.id, sourceHandle: 'vae', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE },
|
||||
[source, connector, target],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual('nodes.fieldTypesMustMatch');
|
||||
});
|
||||
|
||||
it('should preserve type information through chained connectors', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connectorA = buildConnectorNode('connector-a');
|
||||
const connectorB = buildConnectorNode('connector-b');
|
||||
const n2 = buildNode(sub);
|
||||
const edges = [
|
||||
buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
|
||||
];
|
||||
const r = validateConnection(
|
||||
{ source: connectorB.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' },
|
||||
[n1, connectorA, connectorB, n2],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual(null);
|
||||
});
|
||||
|
||||
it('should reject cycles routed through connectors', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const edges = [
|
||||
buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'),
|
||||
];
|
||||
const r = validateConnection(
|
||||
{ source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' },
|
||||
[n1, n2, connector],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual('nodes.connectionWouldCreateCycle');
|
||||
});
|
||||
|
||||
it('should preserve collect item validation through connectors', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(collect);
|
||||
const n3 = buildNode(main_model_loader);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const edges = [
|
||||
buildEdge(n1.id, 'value', n2.id, 'item'),
|
||||
buildEdge(n3.id, 'vae', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
];
|
||||
const r = validateConnection(
|
||||
{ source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'item' },
|
||||
[n1, n2, n3, connector],
|
||||
edges,
|
||||
templates,
|
||||
null
|
||||
);
|
||||
expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes');
|
||||
});
|
||||
|
||||
it('should preserve if branch validation through connectors', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(img_resize);
|
||||
const n3 = buildNode(ifTemplate);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const edges = [
|
||||
buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n3.id, 'true_input'),
|
||||
];
|
||||
const r = validateConnection(
|
||||
{ source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'false_input' },
|
||||
[n1, n2, n3, connector],
|
||||
edges,
|
||||
{ ...templates, if: ifTemplate },
|
||||
null
|
||||
);
|
||||
expect(r).toEqual('nodes.fieldTypesMustMatch');
|
||||
});
|
||||
|
||||
it('should reject connector deletion splice-through when it would duplicate an existing direct edge', () => {
|
||||
const n1 = buildNode(add);
|
||||
const n2 = buildNode(sub);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const edges = [
|
||||
buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'),
|
||||
buildEdge(n1.id, 'value', n2.id, 'a'),
|
||||
];
|
||||
|
||||
expect(getConnectorDeletionSpliceConnections(connector.id, [n1, n2, connector], edges, templates)).toBe(null);
|
||||
});
|
||||
|
||||
it('should reject connector deletion splice-through when fan-out would violate a single-input target', () => {
|
||||
const n1 = buildNode(add);
|
||||
const connectorA = buildConnectorNode('connector-a');
|
||||
const connectorB = buildConnectorNode('connector-b');
|
||||
const n2 = buildNode(sub);
|
||||
const edges = [
|
||||
buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'),
|
||||
buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'),
|
||||
];
|
||||
|
||||
expect(
|
||||
getConnectorDeletionSpliceConnections(connectorA.id, [n1, connectorA, connectorB, n2], edges, templates)
|
||||
).toBe(null);
|
||||
});
|
||||
});
|
||||
|
||||
describe('non-strict mode', () => {
|
||||
it('should reject connections from self to self in non-strict mode', () => {
|
||||
const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' };
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import type { Connection as NullableConnection } from '@xyflow/react';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual';
|
||||
import {
|
||||
CONNECTOR_INPUT_HANDLE,
|
||||
CONNECTOR_OUTPUT_HANDLE,
|
||||
resolveConnectorSource,
|
||||
} from 'features/nodes/store/util/connectorTopology';
|
||||
import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType';
|
||||
import { getHasCycles } from 'features/nodes/store/util/getHasCycles';
|
||||
import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes';
|
||||
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
|
||||
import type { FieldType } from 'features/nodes/types/field';
|
||||
import type { AnyEdge, AnyNode, InvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { SetNonNullable } from 'type-fest';
|
||||
|
||||
type Connection = SetNonNullable<NullableConnection>;
|
||||
@@ -18,6 +25,12 @@ type ValidateConnectionFunc = (
|
||||
strict?: boolean
|
||||
) => string | null;
|
||||
|
||||
type EffectiveSource = {
|
||||
node: InvocationNode;
|
||||
handle: string;
|
||||
fieldTemplate: NonNullable<Templates[string]>['outputs'][string];
|
||||
};
|
||||
|
||||
const getEqualityPredicate =
|
||||
(c: Connection) =>
|
||||
(e: AnyEdge): boolean => {
|
||||
@@ -41,10 +54,7 @@ const isIfInputHandle = (handle: string): handle is (typeof IF_INPUT_HANDLES)[nu
|
||||
return IF_INPUT_HANDLES.includes(handle as (typeof IF_INPUT_HANDLES)[number]);
|
||||
};
|
||||
|
||||
const isSingleCollectionPairOfSameBaseType = (
|
||||
firstType: { name: string; cardinality: string; batch: boolean },
|
||||
secondType: { name: string; cardinality: string; batch: boolean }
|
||||
) => {
|
||||
const isSingleCollectionPairOfSameBaseType = (firstType: FieldType, secondType: FieldType) => {
|
||||
const isSingleToCollection =
|
||||
firstType.cardinality === 'SINGLE' && secondType.cardinality === 'COLLECTION' && firstType.name === secondType.name;
|
||||
const isCollectionToSingle =
|
||||
@@ -52,6 +62,139 @@ const isSingleCollectionPairOfSameBaseType = (
|
||||
return firstType.batch === secondType.batch && (isSingleToCollection || isCollectionToSingle);
|
||||
};
|
||||
|
||||
const areFieldTypesCompatible = (firstType: FieldType, secondType: FieldType) =>
|
||||
validateConnectionTypes(firstType, secondType) ||
|
||||
validateConnectionTypes(secondType, firstType) ||
|
||||
isSingleCollectionPairOfSameBaseType(firstType, secondType);
|
||||
|
||||
type ConnectorTerminalTargetEdge = AnyEdge & {
|
||||
type: 'default';
|
||||
sourceHandle: string;
|
||||
targetHandle: string;
|
||||
};
|
||||
|
||||
const getConnectorTerminalTargetEdges = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => {
|
||||
const visited = new Set<string>();
|
||||
const resolve = (currentConnectorId: string): ConnectorTerminalTargetEdge[] => {
|
||||
if (visited.has(currentConnectorId)) {
|
||||
return [];
|
||||
}
|
||||
visited.add(currentConnectorId);
|
||||
|
||||
return edges.flatMap((edge) => {
|
||||
if (
|
||||
edge.type !== 'default' ||
|
||||
edge.source !== currentConnectorId ||
|
||||
edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE ||
|
||||
typeof edge.targetHandle !== 'string'
|
||||
) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((node) => node.id === edge.target);
|
||||
if (targetNode && isConnectorNode(targetNode)) {
|
||||
return resolve(targetNode.id);
|
||||
}
|
||||
|
||||
return [edge as ConnectorTerminalTargetEdge];
|
||||
});
|
||||
};
|
||||
|
||||
return resolve(connectorId);
|
||||
};
|
||||
|
||||
const getConnectorSubgraphEdgeIds = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => {
|
||||
const visited = new Set<string>();
|
||||
const edgeIds = new Set<string>();
|
||||
|
||||
const visit = (currentConnectorId: string) => {
|
||||
if (visited.has(currentConnectorId)) {
|
||||
return;
|
||||
}
|
||||
visited.add(currentConnectorId);
|
||||
|
||||
edges.forEach((edge) => {
|
||||
if (
|
||||
edge.type !== 'default' ||
|
||||
edge.source !== currentConnectorId ||
|
||||
edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE ||
|
||||
typeof edge.targetHandle !== 'string'
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
edgeIds.add(edge.id);
|
||||
|
||||
const targetNode = nodes.find((node) => node.id === edge.target);
|
||||
if (targetNode && isConnectorNode(targetNode)) {
|
||||
visit(targetNode.id);
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
visit(connectorId);
|
||||
return edgeIds;
|
||||
};
|
||||
|
||||
const getEffectiveSource = (
|
||||
sourceId: string,
|
||||
sourceHandle: string,
|
||||
nodes: AnyNode[],
|
||||
edges: AnyEdge[],
|
||||
templates: Templates
|
||||
): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => {
|
||||
const sourceNode = nodes.find((n) => n.id === sourceId);
|
||||
if (!sourceNode) {
|
||||
return 'nodes.missingNode';
|
||||
}
|
||||
|
||||
if (isConnectorNode(sourceNode)) {
|
||||
if (sourceHandle !== CONNECTOR_OUTPUT_HANDLE) {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges);
|
||||
if (!resolvedSource) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return getEffectiveSource(resolvedSource.nodeId, resolvedSource.fieldName, nodes, edges, templates);
|
||||
}
|
||||
|
||||
if (!isInvocationNode(sourceNode)) {
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
return {
|
||||
node: sourceNode,
|
||||
handle: sourceHandle,
|
||||
fieldTemplate: sourceFieldTemplate,
|
||||
};
|
||||
};
|
||||
|
||||
const getEffectiveSourceForEdge = (
|
||||
edge: AnyEdge,
|
||||
nodes: AnyNode[],
|
||||
edges: AnyEdge[],
|
||||
templates: Templates
|
||||
): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => {
|
||||
if (edge.type !== 'default' || typeof edge.sourceHandle !== 'string') {
|
||||
return null;
|
||||
}
|
||||
|
||||
return getEffectiveSource(edge.source, edge.sourceHandle, nodes, edges, templates);
|
||||
};
|
||||
|
||||
/**
|
||||
* Validates a connection between two fields
|
||||
* @returns A translation key for an error if the connection is invalid, otherwise null
|
||||
@@ -83,18 +226,63 @@ export const validateConnection: ValidateConnectionFunc = (
|
||||
return 'nodes.cannotDuplicateConnection';
|
||||
}
|
||||
|
||||
const sourceNode = nodes.find((n) => n.id === c.source);
|
||||
if (!sourceNode) {
|
||||
return 'nodes.missingNode';
|
||||
}
|
||||
|
||||
const targetNode = nodes.find((n) => n.id === c.target);
|
||||
const sourceNode = nodes.find((n) => n.id === c.source);
|
||||
if (!targetNode) {
|
||||
return 'nodes.missingNode';
|
||||
}
|
||||
|
||||
const sourceTemplate = templates[sourceNode.data.type];
|
||||
if (!sourceTemplate) {
|
||||
const effectiveSource = getEffectiveSource(c.source, c.sourceHandle, nodes, filteredEdges, templates);
|
||||
if (effectiveSource === 'nodes.missingNode') {
|
||||
return 'nodes.missingNode';
|
||||
}
|
||||
if (effectiveSource === 'nodes.missingInvocationTemplate') {
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
if (effectiveSource === 'nodes.missingFieldTemplate') {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
if (isConnectorNode(targetNode)) {
|
||||
if (c.targetHandle !== CONNECTOR_INPUT_HANDLE) {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
if (filteredEdges.find(getTargetEqualityPredicate(c))) {
|
||||
return 'nodes.inputMayOnlyHaveOneConnection';
|
||||
}
|
||||
|
||||
if (effectiveSource) {
|
||||
const connectorSubgraphEdgeIds = getConnectorSubgraphEdgeIds(targetNode.id, nodes, filteredEdges);
|
||||
const stagedEdges = filteredEdges.filter((edge) => !connectorSubgraphEdgeIds.has(edge.id));
|
||||
const terminalTargetEdges = getConnectorTerminalTargetEdges(targetNode.id, nodes, filteredEdges);
|
||||
|
||||
for (const terminalTargetEdge of terminalTargetEdges) {
|
||||
const downstreamValidation = validateConnection(
|
||||
{
|
||||
source: c.source,
|
||||
sourceHandle: c.sourceHandle,
|
||||
target: terminalTargetEdge.target,
|
||||
targetHandle: terminalTargetEdge.targetHandle,
|
||||
},
|
||||
nodes,
|
||||
stagedEdges,
|
||||
templates,
|
||||
null,
|
||||
true
|
||||
);
|
||||
|
||||
if (downstreamValidation !== null) {
|
||||
return downstreamValidation;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unresolved connector chains are allowed to terminate on another connector.
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!isInvocationNode(targetNode)) {
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
@@ -103,11 +291,6 @@ export const validateConnection: ValidateConnectionFunc = (
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle];
|
||||
if (!sourceFieldTemplate) {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
const targetFieldTemplate = targetTemplate.inputs[c.targetHandle];
|
||||
if (!targetFieldTemplate) {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
@@ -117,23 +300,58 @@ export const validateConnection: ValidateConnectionFunc = (
|
||||
return 'nodes.cannotConnectToDirectInput';
|
||||
}
|
||||
|
||||
if (!effectiveSource) {
|
||||
if (sourceNode && isConnectorNode(sourceNode) && c.sourceHandle === CONNECTOR_OUTPUT_HANDLE) {
|
||||
const existingTerminalTargetEdges = getConnectorTerminalTargetEdges(sourceNode.id, nodes, filteredEdges).filter(
|
||||
(edge) => !(edge.target === c.target && edge.targetHandle === c.targetHandle)
|
||||
);
|
||||
|
||||
for (const terminalTargetEdge of existingTerminalTargetEdges) {
|
||||
const constrainedTargetNode = nodes.find((node) => node.id === terminalTargetEdge.target);
|
||||
if (!constrainedTargetNode || !isInvocationNode(constrainedTargetNode)) {
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
const constrainedTargetTemplate = templates[constrainedTargetNode.data.type];
|
||||
if (!constrainedTargetTemplate) {
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
const constrainedTargetFieldTemplate = constrainedTargetTemplate.inputs[terminalTargetEdge.targetHandle];
|
||||
if (!constrainedTargetFieldTemplate) {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
if (!areFieldTypesCompatible(constrainedTargetFieldTemplate.type, targetFieldTemplate.type)) {
|
||||
return 'nodes.fieldTypesMustMatch';
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
return 'nodes.fieldTypesMustMatch';
|
||||
}
|
||||
|
||||
const { node: resolvedSourceNode, handle: sourceHandle, fieldTemplate: sourceFieldTemplate } = effectiveSource;
|
||||
|
||||
if (targetNode.data.type === 'collect' && c.targetHandle === 'item') {
|
||||
// Collect nodes shouldn't mix and match field types.
|
||||
const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
const collectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id);
|
||||
if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) {
|
||||
return 'nodes.cannotMixAndMatchCollectionItemTypes';
|
||||
}
|
||||
}
|
||||
|
||||
if (
|
||||
sourceNode.data.type === 'collect' &&
|
||||
c.sourceHandle === 'collection' &&
|
||||
resolvedSourceNode.data.type === 'collect' &&
|
||||
sourceHandle === 'collection' &&
|
||||
targetNode.data.type === 'collect' &&
|
||||
c.targetHandle === 'collection'
|
||||
) {
|
||||
// Chained collect nodes should preserve a single item type when both ends are already typed.
|
||||
const sourceCollectItemType = getCollectItemType(templates, nodes, edges, sourceNode.id);
|
||||
const targetCollectItemType = getCollectItemType(templates, nodes, edges, targetNode.id);
|
||||
const sourceCollectItemType = getCollectItemType(templates, nodes, filteredEdges, resolvedSourceNode.id);
|
||||
const targetCollectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id);
|
||||
if (
|
||||
sourceCollectItemType &&
|
||||
targetCollectItemType &&
|
||||
@@ -148,33 +366,24 @@ export const validateConnection: ValidateConnectionFunc = (
|
||||
const siblingInputEdge = filteredEdges.find((e) => e.target === c.target && e.targetHandle === siblingHandle);
|
||||
|
||||
if (siblingInputEdge) {
|
||||
if (siblingInputEdge.source === null || siblingInputEdge.source === undefined) {
|
||||
const siblingEffectiveSource = getEffectiveSourceForEdge(siblingInputEdge, nodes, filteredEdges, templates);
|
||||
if (siblingEffectiveSource === 'nodes.missingNode') {
|
||||
return 'nodes.missingNode';
|
||||
}
|
||||
|
||||
if (siblingInputEdge.sourceHandle === null || siblingInputEdge.sourceHandle === undefined) {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
|
||||
const siblingSourceNode = nodes.find((n) => n.id === siblingInputEdge.source);
|
||||
if (!siblingSourceNode) {
|
||||
return 'nodes.missingNode';
|
||||
}
|
||||
|
||||
const siblingSourceTemplate = templates[siblingSourceNode.data.type];
|
||||
if (!siblingSourceTemplate) {
|
||||
if (siblingEffectiveSource === 'nodes.missingInvocationTemplate') {
|
||||
return 'nodes.missingInvocationTemplate';
|
||||
}
|
||||
|
||||
const siblingSourceFieldTemplate = siblingSourceTemplate.outputs[siblingInputEdge.sourceHandle];
|
||||
if (!siblingSourceFieldTemplate) {
|
||||
if (siblingEffectiveSource === 'nodes.missingFieldTemplate') {
|
||||
return 'nodes.missingFieldTemplate';
|
||||
}
|
||||
if (!siblingEffectiveSource) {
|
||||
return 'nodes.fieldTypesMustMatch';
|
||||
}
|
||||
|
||||
const areIfInputTypesCompatible =
|
||||
validateConnectionTypes(sourceFieldTemplate.type, siblingSourceFieldTemplate.type) ||
|
||||
validateConnectionTypes(siblingSourceFieldTemplate.type, sourceFieldTemplate.type) ||
|
||||
isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingSourceFieldTemplate.type);
|
||||
validateConnectionTypes(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type) ||
|
||||
validateConnectionTypes(siblingEffectiveSource.fieldTemplate.type, sourceFieldTemplate.type) ||
|
||||
isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type);
|
||||
|
||||
if (!areIfInputTypesCompatible) {
|
||||
return 'nodes.fieldTypesMustMatch';
|
||||
@@ -189,30 +398,17 @@ export const validateConnection: ValidateConnectionFunc = (
|
||||
}
|
||||
}
|
||||
|
||||
if (sourceNode.data.type === 'if' && c.sourceHandle === 'value') {
|
||||
if (resolvedSourceNode.data.type === 'if' && sourceHandle === 'value') {
|
||||
const ifInputEdges = filteredEdges.filter(
|
||||
(e) => e.target === sourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle)
|
||||
(e) =>
|
||||
e.target === resolvedSourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle)
|
||||
);
|
||||
const ifInputTypes = ifInputEdges.flatMap((edge) => {
|
||||
if (edge.source === null || edge.source === undefined) {
|
||||
const ifInputSource = getEffectiveSourceForEdge(edge, nodes, filteredEdges, templates);
|
||||
if (!ifInputSource || typeof ifInputSource === 'string') {
|
||||
return [];
|
||||
}
|
||||
if (edge.sourceHandle === null || edge.sourceHandle === undefined) {
|
||||
return [];
|
||||
}
|
||||
const ifInputSourceNode = nodes.find((n) => n.id === edge.source);
|
||||
if (!ifInputSourceNode) {
|
||||
return [];
|
||||
}
|
||||
const ifInputSourceTemplate = templates[ifInputSourceNode.data.type];
|
||||
if (!ifInputSourceTemplate) {
|
||||
return [];
|
||||
}
|
||||
const ifInputSourceFieldTemplate = ifInputSourceTemplate.outputs[edge.sourceHandle];
|
||||
if (!ifInputSourceFieldTemplate) {
|
||||
return [];
|
||||
}
|
||||
return [ifInputSourceFieldTemplate.type];
|
||||
return [ifInputSource.fieldTemplate.type];
|
||||
});
|
||||
|
||||
if (ifInputTypes.length > 0) {
|
||||
|
||||
@@ -43,6 +43,12 @@ export const zNotesNodeData = z.object({
|
||||
isOpen: z.boolean(),
|
||||
notes: z.string(),
|
||||
});
|
||||
export const zConnectorNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('connector'),
|
||||
label: z.string(),
|
||||
isOpen: z.boolean(),
|
||||
});
|
||||
const zCurrentImageNodeData = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('current_image'),
|
||||
@@ -52,6 +58,7 @@ const zCurrentImageNodeData = z.object({
|
||||
|
||||
export type NotesNodeData = z.infer<typeof zNotesNodeData>;
|
||||
export type InvocationNodeData = z.infer<typeof zInvocationNodeData>;
|
||||
export type ConnectorNodeData = z.infer<typeof zConnectorNodeData>;
|
||||
type CurrentImageNodeData = z.infer<typeof zCurrentImageNodeData>;
|
||||
|
||||
const zInvocationNodeValidationSchema = z.looseObject({
|
||||
@@ -70,6 +77,15 @@ const zNotesNodeValidationSchema = z.looseObject({
|
||||
const zNotesNode = z.custom<Node<NotesNodeData, 'notes'>>((val) => zNotesNodeValidationSchema.safeParse(val).success);
|
||||
export type NotesNode = z.infer<typeof zNotesNode>;
|
||||
|
||||
const zConnectorNodeValidationSchema = z.looseObject({
|
||||
type: z.literal('connector'),
|
||||
data: zConnectorNodeData,
|
||||
});
|
||||
const zConnectorNode = z.custom<Node<ConnectorNodeData, 'connector'>>(
|
||||
(val) => zConnectorNodeValidationSchema.safeParse(val).success
|
||||
);
|
||||
export type ConnectorNode = z.infer<typeof zConnectorNode>;
|
||||
|
||||
const zCurrentImageNodeValidationSchema = z.looseObject({
|
||||
type: z.literal('current_image'),
|
||||
data: zCurrentImageNodeData,
|
||||
@@ -79,12 +95,14 @@ const zCurrentImageNode = z.custom<Node<CurrentImageNodeData, 'current_image'>>(
|
||||
);
|
||||
export type CurrentImageNode = z.infer<typeof zCurrentImageNode>;
|
||||
|
||||
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]);
|
||||
export const zAnyNode = z.union([zInvocationNode, zNotesNode, zConnectorNode, zCurrentImageNode]);
|
||||
export type AnyNode = z.infer<typeof zAnyNode>;
|
||||
|
||||
export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode =>
|
||||
Boolean(node && node.type === 'invocation');
|
||||
export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes');
|
||||
export const isConnectorNode = (node?: AnyNode | null): node is ConnectorNode =>
|
||||
Boolean(node && node.type === 'connector');
|
||||
// #endregion
|
||||
|
||||
// #region NodeExecutionState
|
||||
|
||||
@@ -3,7 +3,7 @@ import { z } from 'zod';
|
||||
|
||||
import type { FieldType } from './field';
|
||||
import { zFieldIdentifier } from './field';
|
||||
import { zInvocationNodeData, zNotesNodeData } from './invocation';
|
||||
import { zConnectorNodeData, zInvocationNodeData, zNotesNodeData } from './invocation';
|
||||
|
||||
// #region Workflow misc
|
||||
const zXYPosition = z
|
||||
@@ -31,7 +31,13 @@ const zWorkflowNotesNode = z.object({
|
||||
data: zNotesNodeData,
|
||||
position: zXYPosition,
|
||||
});
|
||||
const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]);
|
||||
const zWorkflowConnectorNode = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
type: z.literal('connector'),
|
||||
data: zConnectorNodeData,
|
||||
position: zXYPosition,
|
||||
});
|
||||
const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode, zWorkflowConnectorNode]);
|
||||
|
||||
type WorkflowInvocationNode = z.infer<typeof zWorkflowInvocationNode>;
|
||||
|
||||
@@ -377,7 +383,7 @@ export const zWorkflowV3 = z.object({
|
||||
exposedFields: z.array(zFieldIdentifier),
|
||||
meta: z.object({
|
||||
category: zWorkflowCategory.default('user'),
|
||||
version: z.literal('3.0.0'),
|
||||
version: z.literal('4.0.0'),
|
||||
}),
|
||||
// Use the validated form schema!
|
||||
form: zValidatedBuilderForm,
|
||||
|
||||
@@ -0,0 +1,194 @@
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
|
||||
import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { buildNodesGraph } from './buildNodesGraph';
|
||||
|
||||
const buildConnectorNode = (id: string) => ({
|
||||
id,
|
||||
type: 'connector' as const,
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
id,
|
||||
type: 'connector' as const,
|
||||
label: 'Connector',
|
||||
isOpen: true,
|
||||
},
|
||||
});
|
||||
|
||||
const buildState = (nodes: unknown[], edges: unknown[]) =>
|
||||
({
|
||||
nodes: {
|
||||
present: {
|
||||
_version: 1,
|
||||
nodes,
|
||||
edges,
|
||||
formFieldInitialValues: {},
|
||||
id: undefined,
|
||||
name: '',
|
||||
author: '',
|
||||
description: '',
|
||||
version: '',
|
||||
contact: '',
|
||||
tags: '',
|
||||
notes: '',
|
||||
exposedFields: [],
|
||||
meta: { version: '4.0.0', category: 'user' },
|
||||
form: {
|
||||
rootElementId: 'root',
|
||||
elements: {
|
||||
root: {
|
||||
id: 'root',
|
||||
type: 'container',
|
||||
data: { layout: 'column', children: [] },
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
gallery: {
|
||||
autoAddBoardId: 'none',
|
||||
selection: [],
|
||||
},
|
||||
}) as unknown as Parameters<typeof buildNodesGraph>[0];
|
||||
|
||||
describe('buildNodesGraph', () => {
|
||||
it('flattens a single connector to one direct execution edge', () => {
|
||||
const source = buildNode(add);
|
||||
const target = buildNode(sub);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const state = buildState(
|
||||
[source, target, connector],
|
||||
[
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
]
|
||||
);
|
||||
|
||||
const graph = buildNodesGraph(state, templates);
|
||||
|
||||
expect(graph.nodes).not.toHaveProperty(connector.id);
|
||||
expect(graph.edges).toEqual([
|
||||
{
|
||||
source: { node_id: source.id, field: 'value' },
|
||||
destination: { node_id: target.id, field: 'a' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('flattens chained connectors transitively', () => {
|
||||
const source = buildNode(add);
|
||||
const target = buildNode(sub);
|
||||
const connectorA = buildConnectorNode('connector-a');
|
||||
const connectorB = buildConnectorNode('connector-b');
|
||||
const state = buildState(
|
||||
[source, target, connectorA, connectorB],
|
||||
[
|
||||
buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
]
|
||||
);
|
||||
|
||||
const graph = buildNodesGraph(state, templates);
|
||||
|
||||
expect(graph.edges).toEqual([
|
||||
{
|
||||
source: { node_id: source.id, field: 'value' },
|
||||
destination: { node_id: target.id, field: 'a' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('fans out through a connector into multiple execution edges', () => {
|
||||
const source = buildNode(add);
|
||||
const targetA = buildNode(sub);
|
||||
const targetB = buildNode(img_resize);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const state = buildState(
|
||||
[source, targetA, targetB, connector],
|
||||
[
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width'),
|
||||
]
|
||||
);
|
||||
|
||||
const graph = buildNodesGraph(state, templates);
|
||||
|
||||
expect(graph.edges).toEqual([
|
||||
{
|
||||
source: { node_id: source.id, field: 'value' },
|
||||
destination: { node_id: targetA.id, field: 'a' },
|
||||
},
|
||||
{
|
||||
source: { node_id: source.id, field: 'value' },
|
||||
destination: { node_id: targetB.id, field: 'width' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('drops unresolved connector paths from the execution graph', () => {
|
||||
const target = buildNode(sub);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const state = buildState([target, connector], [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]);
|
||||
|
||||
const graph = buildNodesGraph(state, templates);
|
||||
|
||||
expect(graph.nodes).not.toHaveProperty(connector.id);
|
||||
expect(graph.edges).toEqual([]);
|
||||
});
|
||||
|
||||
it('deduplicates effective execution edges created by flattening', () => {
|
||||
const source = buildNode(add);
|
||||
const target = buildNode(sub);
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const state = buildState(
|
||||
[source, target, connector],
|
||||
[
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
buildEdge(source.id, 'value', target.id, 'a'),
|
||||
]
|
||||
);
|
||||
|
||||
const graph = buildNodesGraph(state, templates);
|
||||
|
||||
expect(graph.edges).toEqual([
|
||||
{
|
||||
source: { node_id: source.id, field: 'value' },
|
||||
destination: { node_id: target.id, field: 'a' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('still omits explicit destination input values when the flattened edge exists', () => {
|
||||
const source = buildNode(add);
|
||||
const target = deepClone(buildNode(sub));
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
const inputA = target.data.inputs.a;
|
||||
expect(inputA).toBeDefined();
|
||||
if (!inputA) {
|
||||
throw new Error('Missing input a');
|
||||
}
|
||||
inputA.value = 'not-an-integer' as never;
|
||||
const state = buildState(
|
||||
[source, target, connector],
|
||||
[
|
||||
buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE),
|
||||
buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'),
|
||||
]
|
||||
);
|
||||
|
||||
const graph = buildNodesGraph(state, templates);
|
||||
|
||||
expect(graph.edges).toEqual([
|
||||
{
|
||||
source: { node_id: source.id, field: 'value' },
|
||||
destination: { node_id: target.id, field: 'a' },
|
||||
},
|
||||
]);
|
||||
expect(graph.nodes[target.id]).not.toHaveProperty('a');
|
||||
});
|
||||
});
|
||||
@@ -4,10 +4,11 @@ import { omit, reduce } from 'es-toolkit/compat';
|
||||
import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { resolveConnectorSource } from 'features/nodes/store/util/connectorTopology';
|
||||
import type { BoardField } from 'features/nodes/types/common';
|
||||
import type { BoardFieldInputInstance } from 'features/nodes/types/field';
|
||||
import { isBoardFieldInputInstance, isBoardFieldInputTemplate } from 'features/nodes/types/field';
|
||||
import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isConnectorNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
@@ -96,12 +97,57 @@ export const buildNodesGraph = (state: RootState, templates: Templates): Require
|
||||
const filteredNodeIds = filteredNodes.map(({ id }) => id);
|
||||
|
||||
// skip out the "dummy" edges between collapsed nodes
|
||||
const filteredEdges = edges
|
||||
.filter((edge) => edge.type !== 'collapsed')
|
||||
.filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target));
|
||||
const flattenedEdges = edges
|
||||
.filter((edge) => edge.type === 'default')
|
||||
.flatMap((edge) => {
|
||||
const targetNode = nodes.find((node) => node.id === edge.target);
|
||||
if (!targetNode || !isInvocationNode(targetNode) || !isExecutableNode(targetNode)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const sourceNode = nodes.find((node) => node.id === edge.source);
|
||||
if (!sourceNode) {
|
||||
return [];
|
||||
}
|
||||
|
||||
if (isInvocationNode(sourceNode)) {
|
||||
if (!isExecutableNode(sourceNode) || !filteredNodeIds.includes(sourceNode.id)) {
|
||||
return [];
|
||||
}
|
||||
return [edge];
|
||||
}
|
||||
|
||||
if (isConnectorNode(sourceNode)) {
|
||||
const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges);
|
||||
if (!resolvedSource || !filteredNodeIds.includes(resolvedSource.nodeId)) {
|
||||
return [];
|
||||
}
|
||||
return [
|
||||
{
|
||||
...edge,
|
||||
id: `flattened-${resolvedSource.nodeId}-${resolvedSource.fieldName}-${edge.target}-${edge.targetHandle}`,
|
||||
source: resolvedSource.nodeId,
|
||||
sourceHandle: resolvedSource.fieldName,
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
return [];
|
||||
})
|
||||
.filter((edge, index, allEdges) => {
|
||||
return (
|
||||
allEdges.findIndex(
|
||||
(candidate) =>
|
||||
candidate.source === edge.source &&
|
||||
candidate.sourceHandle === edge.sourceHandle &&
|
||||
candidate.target === edge.target &&
|
||||
candidate.targetHandle === edge.targetHandle
|
||||
) === index
|
||||
);
|
||||
});
|
||||
|
||||
// Reduce the node editor edges into invocation graph edges
|
||||
const parsedEdges = filteredEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
|
||||
const parsedEdges = flattenedEdges.reduce<NonNullable<Graph['edges']>>((edgesAccumulator, edge) => {
|
||||
const { source, target, sourceHandle, targetHandle } = edge;
|
||||
|
||||
if (!sourceHandle || !targetHandle) {
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import type { XYPosition } from '@xyflow/react';
|
||||
import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
|
||||
import type { ConnectorNode } from 'features/nodes/types/invocation';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
export const buildConnectorNode = (position: XYPosition): ConnectorNode => {
|
||||
const nodeId = uuidv4();
|
||||
const node: ConnectorNode = {
|
||||
...SHARED_NODE_PROPERTIES,
|
||||
id: nodeId,
|
||||
type: 'connector',
|
||||
position,
|
||||
data: {
|
||||
id: nodeId,
|
||||
type: 'connector',
|
||||
isOpen: true,
|
||||
label: 'Connector',
|
||||
},
|
||||
};
|
||||
return node;
|
||||
};
|
||||
@@ -5,7 +5,7 @@ import { parseify } from 'common/util/serialize';
|
||||
import { pick } from 'es-toolkit/compat';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { zWorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import i18n from 'i18n';
|
||||
@@ -42,6 +42,9 @@ export const buildWorkflowFast = (nodesState: NodesState): WorkflowV3 => {
|
||||
if (isInvocationNode(node) && node.type) {
|
||||
const { id, type, data, position } = node;
|
||||
newWorkflow.nodes.push({ id, type, data, position });
|
||||
} else if (isConnectorNode(node) && node.type) {
|
||||
const { id, type, data, position } = node;
|
||||
newWorkflow.nodes.push({ id, type, data, position });
|
||||
} else if (isNotesNode(node) && node.type) {
|
||||
const { id, type, data, position } = node;
|
||||
newWorkflow.nodes.push({ id, type, data, position });
|
||||
|
||||
@@ -30,7 +30,7 @@ export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): Wor
|
||||
description: '',
|
||||
meta: {
|
||||
category: 'user',
|
||||
version: '3.0.0',
|
||||
version: '4.0.0',
|
||||
},
|
||||
notes: '',
|
||||
tags: '',
|
||||
|
||||
@@ -70,9 +70,11 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => {
|
||||
};
|
||||
|
||||
const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => {
|
||||
// Bump version
|
||||
(workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0';
|
||||
// Parsing strips out any extra properties not in the latest version
|
||||
return migrateV3toV4(workflowToMigrate as unknown as WorkflowV3);
|
||||
};
|
||||
|
||||
const migrateV3toV4 = (workflowToMigrate: WorkflowV3): WorkflowV3 => {
|
||||
workflowToMigrate.meta.version = '4.0.0';
|
||||
return zWorkflowV3.parse(workflowToMigrate);
|
||||
};
|
||||
|
||||
@@ -100,6 +102,10 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => {
|
||||
workflow = migrateV2toV3(v2);
|
||||
}
|
||||
|
||||
if (get(workflow, 'meta.version') === '3.0.0') {
|
||||
workflow = migrateV3toV4(workflow as WorkflowV3);
|
||||
}
|
||||
|
||||
// We should now have a V3 workflow
|
||||
const migratedWorkflow = zWorkflowV3.parse(workflow);
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { get } from 'es-toolkit/compat';
|
||||
import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology';
|
||||
import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils';
|
||||
import type { WorkflowV3 } from 'features/nodes/types/workflow';
|
||||
import { getDefaultForm } from 'features/nodes/types/workflow';
|
||||
@@ -7,6 +8,17 @@ import { describe, expect, it } from 'vitest';
|
||||
|
||||
//TODO(psyche): Test workflow validation for form builder fields
|
||||
describe('validateWorkflow', () => {
|
||||
const buildConnectorNode = (id: string) => ({
|
||||
id,
|
||||
type: 'connector' as const,
|
||||
position: { x: 0, y: 0 },
|
||||
data: {
|
||||
id,
|
||||
type: 'connector' as const,
|
||||
label: 'Connector',
|
||||
isOpen: true,
|
||||
},
|
||||
});
|
||||
const getWorkflow = (): WorkflowV3 => ({
|
||||
name: '',
|
||||
author: '',
|
||||
@@ -17,7 +29,7 @@ describe('validateWorkflow', () => {
|
||||
notes: '',
|
||||
exposedFields: [],
|
||||
form: getDefaultForm(),
|
||||
meta: { version: '3.0.0', category: 'user' },
|
||||
meta: { version: '4.0.0', category: 'user' },
|
||||
nodes: [
|
||||
{
|
||||
id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad',
|
||||
@@ -104,6 +116,7 @@ describe('validateWorkflow', () => {
|
||||
});
|
||||
expect(validationResult.warnings.length).toBe(1);
|
||||
expect(get(validationResult, 'workflow.nodes[1].data.inputs.image.value')).toBeUndefined();
|
||||
expect(validationResult.workflow.meta.version).toBe('4.0.0');
|
||||
});
|
||||
it('should reset boards that are inaccessible', async () => {
|
||||
const validationResult = await validateWorkflow({
|
||||
@@ -127,4 +140,138 @@ describe('validateWorkflow', () => {
|
||||
expect(validationResult.warnings.length).toBe(1);
|
||||
expect(get(validationResult, 'workflow.nodes[0].data.inputs.model.value')).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should delete malformed connector edges with invalid handles', async () => {
|
||||
const workflow = getWorkflow();
|
||||
workflow.nodes.push(buildConnectorNode('connector-1'));
|
||||
workflow.edges.push({
|
||||
id: 'e1',
|
||||
type: 'default',
|
||||
source: workflow.nodes[0]!.id,
|
||||
sourceHandle: 'vae',
|
||||
target: 'connector-1',
|
||||
targetHandle: 'wrong',
|
||||
});
|
||||
|
||||
const validationResult = await validateWorkflow({
|
||||
workflow,
|
||||
templates: { img_resize, main_model_loader },
|
||||
checkImageAccess: resolveTrue,
|
||||
checkBoardAccess: resolveTrue,
|
||||
checkModelAccess: resolveTrue,
|
||||
});
|
||||
|
||||
expect(validationResult.workflow.edges).toEqual([]);
|
||||
expect(validationResult.warnings.length).toBe(1);
|
||||
});
|
||||
|
||||
it('should delete connector edges with missing endpoints', async () => {
|
||||
const workflow = getWorkflow();
|
||||
workflow.nodes.push(buildConnectorNode('connector-1'));
|
||||
workflow.edges.push({
|
||||
id: 'e1',
|
||||
type: 'default',
|
||||
source: 'missing-node',
|
||||
sourceHandle: 'value',
|
||||
target: 'connector-1',
|
||||
targetHandle: CONNECTOR_INPUT_HANDLE,
|
||||
});
|
||||
|
||||
const validationResult = await validateWorkflow({
|
||||
workflow,
|
||||
templates: { img_resize, main_model_loader },
|
||||
checkImageAccess: resolveTrue,
|
||||
checkBoardAccess: resolveTrue,
|
||||
checkModelAccess: resolveTrue,
|
||||
});
|
||||
|
||||
expect(validationResult.workflow.edges).toEqual([]);
|
||||
expect(validationResult.warnings.length).toBe(1);
|
||||
});
|
||||
|
||||
it('should repair invalid multi-input connector state predictably by keeping the first valid input edge', async () => {
|
||||
const workflow = getWorkflow();
|
||||
const loader2 = structuredClone(workflow.nodes[0]!);
|
||||
loader2.id = 'second-loader';
|
||||
loader2.data.id = 'second-loader';
|
||||
const connector = buildConnectorNode('connector-1');
|
||||
workflow.nodes.push(loader2, connector);
|
||||
workflow.edges.push({
|
||||
id: 'e1',
|
||||
type: 'default',
|
||||
source: workflow.nodes[0]!.id,
|
||||
sourceHandle: 'vae',
|
||||
target: connector.id,
|
||||
targetHandle: CONNECTOR_INPUT_HANDLE,
|
||||
});
|
||||
workflow.edges.push({
|
||||
id: 'e2',
|
||||
type: 'default',
|
||||
source: loader2.id,
|
||||
sourceHandle: 'vae',
|
||||
target: connector.id,
|
||||
targetHandle: CONNECTOR_INPUT_HANDLE,
|
||||
});
|
||||
|
||||
const validationResult = await validateWorkflow({
|
||||
workflow,
|
||||
templates: { img_resize, main_model_loader },
|
||||
checkImageAccess: resolveTrue,
|
||||
checkBoardAccess: resolveTrue,
|
||||
checkModelAccess: resolveTrue,
|
||||
});
|
||||
|
||||
expect(validationResult.workflow.edges).toEqual([
|
||||
{
|
||||
id: 'e1',
|
||||
type: 'default',
|
||||
source: workflow.nodes[0]!.id,
|
||||
sourceHandle: 'vae',
|
||||
target: connector.id,
|
||||
targetHandle: CONNECTOR_INPUT_HANDLE,
|
||||
},
|
||||
]);
|
||||
expect(validationResult.warnings.length).toBe(1);
|
||||
});
|
||||
|
||||
it('should retain isolated connectors during workflow validation', async () => {
|
||||
const workflow = getWorkflow();
|
||||
workflow.nodes.push(buildConnectorNode('connector-1'));
|
||||
|
||||
const validationResult = await validateWorkflow({
|
||||
workflow,
|
||||
templates: { img_resize, main_model_loader },
|
||||
checkImageAccess: resolveTrue,
|
||||
checkBoardAccess: resolveTrue,
|
||||
checkModelAccess: resolveTrue,
|
||||
});
|
||||
|
||||
expect(validationResult.workflow.nodes.find((node) => node.id === 'connector-1')).toBeDefined();
|
||||
expect(validationResult.warnings).toEqual([]);
|
||||
});
|
||||
|
||||
it('should retain unresolved connector output edges that establish downstream constraints in the editor', async () => {
|
||||
const workflow = getWorkflow();
|
||||
workflow.nodes.push(buildConnectorNode('connector-1'));
|
||||
const unresolvedEdge = {
|
||||
id: 'e1',
|
||||
type: 'default' as const,
|
||||
source: 'connector-1',
|
||||
sourceHandle: CONNECTOR_OUTPUT_HANDLE,
|
||||
target: workflow.nodes[1]!.id,
|
||||
targetHandle: 'image',
|
||||
};
|
||||
workflow.edges.push(unresolvedEdge);
|
||||
|
||||
const validationResult = await validateWorkflow({
|
||||
workflow,
|
||||
templates: { img_resize, main_model_loader },
|
||||
checkImageAccess: resolveTrue,
|
||||
checkBoardAccess: resolveTrue,
|
||||
checkModelAccess: resolveTrue,
|
||||
});
|
||||
|
||||
expect(validationResult.workflow.edges).toEqual([unresolvedEdge]);
|
||||
expect(validationResult.warnings).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { addElement, getIsFormEmpty } from 'features/nodes/components/sidePanel/builder/form-manipulation';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { validateConnection } from 'features/nodes/store/util/validateConnection';
|
||||
import {
|
||||
isBoardFieldInputInstance,
|
||||
isImageFieldCollectionInputInstance,
|
||||
@@ -149,8 +150,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
}
|
||||
}
|
||||
|
||||
// Stash invalid edges here to be deleted later
|
||||
const edgesToDelete = new Set<string>();
|
||||
const validEdges = [];
|
||||
|
||||
for (const edge of edges) {
|
||||
// Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow.
|
||||
@@ -169,7 +169,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
);
|
||||
}
|
||||
|
||||
if (!sourceTemplate) {
|
||||
if (sourceNode?.type === 'invocation' && !sourceTemplate) {
|
||||
// The edge's source/output node template does not exist
|
||||
issues.push(
|
||||
t('nodes.missingTemplate', {
|
||||
@@ -198,7 +198,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
);
|
||||
}
|
||||
|
||||
if (!targetTemplate) {
|
||||
if (targetNode?.type === 'invocation' && !targetTemplate) {
|
||||
// The edge's target/input node template does not exist
|
||||
issues.push(
|
||||
t('nodes.missingTemplate', {
|
||||
@@ -218,8 +218,14 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
);
|
||||
}
|
||||
|
||||
if (!issues.length && edge.type === 'default') {
|
||||
const connectionError = validateConnection(edge, nodes, validEdges, templates, null, true);
|
||||
if (connectionError) {
|
||||
issues.push(connectionError);
|
||||
}
|
||||
}
|
||||
|
||||
if (issues.length) {
|
||||
edgesToDelete.add(edge.id);
|
||||
const source = edge.type === 'default' ? `${edge.source}.${edge.sourceHandle}` : edge.source;
|
||||
const target = edge.type === 'default' ? `${edge.source}.${edge.targetHandle}` : edge.target;
|
||||
warnings.push({
|
||||
@@ -227,11 +233,13 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise<Vali
|
||||
issues,
|
||||
data: edge,
|
||||
});
|
||||
} else {
|
||||
validEdges.push(edge);
|
||||
}
|
||||
}
|
||||
|
||||
// Remove invalid edges
|
||||
_workflow.edges = edges.filter(({ id }) => !edgesToDelete.has(id));
|
||||
_workflow.edges = validEdges;
|
||||
|
||||
// Migrated exposed fields to form elements if they exist and the form does not
|
||||
// Note: If the form is invalid per its zod schema, it will be reset to a default, empty form!
|
||||
|
||||
Reference in New Issue
Block a user