From dd056067a9ac23b530ab9f6ba8c410cf252b092d Mon Sep 17 00:00:00 2001 From: Jonathan <34005131+JPPhoto@users.noreply.github.com> Date: Mon, 13 Apr 2026 19:03:59 -0500 Subject: [PATCH 1/6] Added workflow connectors (#9027) * Add persisted workflow connectors * Polish workflow connector menu and visuals * Refine connector sizing and alignment * Fix connector deletion and unresolved constraints * Format connector deletion tests * Revert frontend test config tweaks * Address workflow connector review feedback * Format workflow flow menu memo --- invokeai/frontend/web/public/locales/en.json | 2 + .../features/nodes/components/flow/Flow.tsx | 347 +++++++++++++++++- .../flow/edges/util/buildEdgeSelectors.ts | 28 +- .../flow/nodes/Connector/ConnectorNode.tsx | 97 +++++ .../nodes/common/NonInvocationNodeWrapper.tsx | 11 +- .../components/flow/nodes/common/shared.ts | 8 +- .../src/features/nodes/hooks/useBuildNode.ts | 7 +- .../src/features/nodes/hooks/useConnection.ts | 47 +++ .../features/nodes/store/nodesSlice.test.ts | 160 ++++++++ .../src/features/nodes/store/nodesSlice.ts | 78 +++- .../store/util/connectorTopology.test.ts | 126 +++++++ .../nodes/store/util/connectorTopology.ts | 228 ++++++++++++ .../util/getFirstValidConnection.test.ts | 112 +++++- .../store/util/getFirstValidConnection.ts | 81 +++- .../nodes/store/util/reactFlowUtil.test.ts | 23 ++ .../nodes/store/util/reactFlowUtil.ts | 1 + .../store/util/validateConnection.test.ts | 225 ++++++++++++ .../nodes/store/util/validateConnection.ts | 316 +++++++++++++--- .../src/features/nodes/types/invocation.ts | 20 +- .../web/src/features/nodes/types/workflow.ts | 12 +- .../nodes/util/graph/buildNodesGraph.test.ts | 194 ++++++++++ .../nodes/util/graph/buildNodesGraph.ts | 56 ++- .../nodes/util/node/buildConnectorNode.ts | 21 ++ .../nodes/util/workflow/buildWorkflow.ts | 5 +- .../nodes/util/workflow/graphToWorkflow.ts | 2 +- .../nodes/util/workflow/migrations.ts | 12 +- .../util/workflow/validateWorkflow.test.ts | 149 +++++++- .../nodes/util/workflow/validateWorkflow.ts | 20 +- 28 files changed, 2260 insertions(+), 128 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx create mode 100644 invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts create mode 100644 invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts create mode 100644 invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 201ea8badb..7fad98531e 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1417,6 +1417,8 @@ "notes": "Notes", "description": "Description", "notesDescription": "Add notes about your workflow", + "addConnector": "Add Connector", + "deleteConnector": "Delete Connector", "problemSettingTitle": "Problem Setting Title", "resetToDefaultValue": "Reset to default value", "reloadNodeTemplates": "Reload Node Templates", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx index 0c48eddfc6..2fc5e12384 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/Flow.tsx @@ -1,4 +1,4 @@ -import { useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; +import { Menu, MenuButton, MenuItem, MenuList, Portal, useGlobalMenuClose, useToken } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; import type { EdgeChange, @@ -32,7 +32,9 @@ import { $edgePendingUpdate, $lastEdgeUpdateMouseEvent, $pendingConnection, + $templates, $viewport, + connectorInserted, edgesChanged, nodesChanged, redo, @@ -46,18 +48,24 @@ import { selectNodes, selectNodesSlice, } from 'features/nodes/store/selectors'; +import { getConnectorDeletionSpliceConnections } from 'features/nodes/store/util/connectorTopology'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { selectSelectionMode, selectShouldSnapToGrid } from 'features/nodes/store/workflowSettingsSlice'; import { NO_DRAG_CLASS, NO_PAN_CLASS, NO_WHEEL_CLASS } from 'features/nodes/types/constants'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; import type { CSSProperties, MouseEvent, RefObject } from 'react'; -import { memo, useCallback, useMemo, useRef } from 'react'; +import { memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { useHotkeys } from 'react-hotkeys-hook'; +import { useTranslation } from 'react-i18next'; +import { PiPlugsConnectedBold, PiTrashBold } from 'react-icons/pi'; import CustomConnectionLine from './connectionLines/CustomConnectionLine'; import InvocationCollapsedEdge from './edges/InvocationCollapsedEdge'; import InvocationDefaultEdge from './edges/InvocationDefaultEdge'; +import ConnectorNode from './nodes/Connector/ConnectorNode'; import CurrentImageNode from './nodes/CurrentImage/CurrentImageNode'; import InvocationNodeWrapper from './nodes/Invocation/InvocationNodeWrapper'; import NotesNode from './nodes/Notes/NotesNode'; @@ -70,6 +78,7 @@ const edgeTypes = { const nodeTypes = { invocation: InvocationNodeWrapper, + connector: ConnectorNode, current_image: CurrentImageNode, notes: NotesNode, } as const; @@ -81,17 +90,69 @@ const snapGrid: [number, number] = [25, 25]; const selectCancelConnection = (state: ReactFlowState) => state.cancelConnection; +type WorkflowContextMenuState = + | { + kind: 'pane'; + clientX: number; + clientY: number; + pageX: number; + pageY: number; + } + | { + kind: 'connector'; + connectorId: string; + pageX: number; + pageY: number; + } + | null; + +const getWorkflowContextMenuState = ( + event: globalThis.MouseEvent, + flowWrapper: HTMLDivElement | null +): WorkflowContextMenuState => { + if (event.shiftKey || !(event.target instanceof Element) || !flowWrapper?.contains(event.target)) { + return null; + } + + const connectorId = event.target.closest('[data-connector-node-id]')?.dataset.connectorNodeId; + if (connectorId) { + return { + kind: 'connector', + connectorId, + pageX: event.pageX, + pageY: event.pageY, + }; + } + + const paneTarget = event.target.closest('.react-flow__pane'); + if (paneTarget && flowWrapper.contains(paneTarget)) { + return { + kind: 'pane', + clientX: event.clientX, + clientY: event.clientY, + pageX: event.pageX, + pageY: event.pageY, + }; + } + + return null; +}; + export const Flow = memo(() => { + const { t } = useTranslation(); const dispatch = useAppDispatch(); const nodes = useAppSelector(selectNodes); const edges = useAppSelector(selectEdges); + const templates = useStore($templates); const viewport = useStore($viewport); const shouldSnapToGrid = useAppSelector(selectShouldSnapToGrid); const selectionMode = useAppSelector(selectSelectionMode); const { onConnectStart, onConnect, onConnectEnd } = useConnection(); const flowWrapper = useRef(null); + const pendingNodeInternalsUpdateRef = useRef(null); const isValidConnection = useIsValidConnection(); const updateNodeInternals = useUpdateNodeInternals(); + const [contextMenuState, setContextMenuState] = useState(null); useFocusRegion('workflows', flowWrapper); @@ -127,9 +188,16 @@ export const Flow = memo(() => { }, []); const { onCloseGlobal } = useGlobalMenuClose(); - const handlePaneClick = useCallback(() => { - onCloseGlobal(); - }, [onCloseGlobal]); + const handlePaneClick: NonNullable['onPaneClick']> = useCallback( + (event) => { + if ('button' in event && event.button !== 0) { + return; + } + onCloseGlobal(); + setContextMenuState(null); + }, + [onCloseGlobal] + ); const onInit: OnInit = useCallback((flow) => { $flow.set(flow); @@ -147,6 +215,149 @@ export const Flow = memo(() => { } }, []); + useEffect(() => { + const pendingNodeIds = pendingNodeInternalsUpdateRef.current; + if (!pendingNodeIds) { + return; + } + + pendingNodeInternalsUpdateRef.current = null; + + const frameId = requestAnimationFrame(() => { + updateNodeInternals([...new Set(pendingNodeIds)]); + }); + + return () => { + cancelAnimationFrame(frameId); + }; + }, [edges, nodes, updateNodeInternals]); + + const addConnectorAtPaneMenuPosition = useCallback(() => { + if (contextMenuState?.kind !== 'pane') { + return; + } + const flow = $flow.get(); + if (!flow) { + return; + } + const connector = buildConnectorNode( + flow.screenToFlowPosition({ + x: contextMenuState.clientX, + y: contextMenuState.clientY, + }) + ); + dispatch(nodesChanged([{ type: 'add', item: connector }])); + setContextMenuState(null); + }, [contextMenuState, dispatch]); + + const connectorSpliceConnections = useMemo( + () => + contextMenuState?.kind === 'connector' + ? getConnectorDeletionSpliceConnections( + contextMenuState.connectorId, + nodes, + edges, + templates, + validateConnection + ) + : null, + [contextMenuState, edges, nodes, templates] + ); + + const deleteConnectorFromContextMenu = useCallback(() => { + if (contextMenuState?.kind !== 'connector' || !connectorSpliceConnections) { + return; + } + const connectorEdgeRemovals: EdgeChange[] = edges + .filter((edge) => edge.source === contextMenuState.connectorId || edge.target === contextMenuState.connectorId) + .map((edge) => ({ type: 'remove', id: edge.id })); + const spliceEdgeAdditions: EdgeChange[] = connectorSpliceConnections.map((connection) => ({ + type: 'add', + item: connectionToEdge(connection), + })); + + pendingNodeInternalsUpdateRef.current = [ + contextMenuState.connectorId, + ...connectorSpliceConnections.flatMap((connection) => [connection.source, connection.target]), + ]; + dispatch(edgesChanged([...connectorEdgeRemovals, ...spliceEdgeAdditions])); + dispatch(nodesChanged([{ type: 'remove', id: contextMenuState.connectorId }])); + setContextMenuState(null); + }, [connectorSpliceConnections, contextMenuState, dispatch, edges]); + + useEffect(() => { + const onWindowContextMenu = (event: globalThis.MouseEvent) => { + const nextContextMenuState = getWorkflowContextMenuState(event, flowWrapper.current); + if (!nextContextMenuState) { + return; + } + + event.preventDefault(); + event.stopPropagation(); + setContextMenuState(nextContextMenuState); + }; + + window.addEventListener('contextmenu', onWindowContextMenu, { capture: true }); + + return () => { + window.removeEventListener('contextmenu', onWindowContextMenu, { capture: true }); + }; + }, []); + + const renderContextMenu = useCallback(() => { + if (contextMenuState?.kind === 'pane') { + return ( + + } onClick={addConnectorAtPaneMenuPosition}> + {t('nodes.addConnector')} + + + ); + } + + if (contextMenuState?.kind === 'connector') { + return ( + + } + onClick={deleteConnectorFromContextMenu} + isDisabled={!connectorSpliceConnections} + isDestructive + > + {t('nodes.deleteConnector')} + + + ); + } + + return ; + }, [addConnectorAtPaneMenuPosition, connectorSpliceConnections, contextMenuState, deleteConnectorFromContextMenu, t]); + + const closeContextMenu = useCallback(() => { + setContextMenuState(null); + }, []); + + const onEdgeDoubleClick = useCallback>( + (event, edge) => { + if (edge.type !== 'default' || edge.hidden) { + return; + } + const flow = $flow.get(); + if (!flow) { + return; + } + const connector = buildConnectorNode( + flow.screenToFlowPosition({ + x: event.clientX, + y: event.clientY, + }) + ); + dispatch(connectorInserted({ edgeId: edge.id, connector })); + updateNodeInternals([edge.source, edge.target, connector.id]); + }, + [dispatch, updateNodeInternals] + ); + // #region Updatable Edges /** @@ -209,16 +420,130 @@ export const Flow = memo(() => { // #endregion + const renderedNodes = useMemo(() => nodes, [nodes]); + + const renderedEdges = useMemo(() => edges, [edges]); + const contextMenuPosition = contextMenuState ? { x: contextMenuState.pageX, y: contextMenuState.pageY } : null; + const contextMenuKey = contextMenuPosition ? `${contextMenuPosition.x}-${contextMenuPosition.y}` : 'closed'; + return ( <> + + + + + {renderContextMenu()} + + + + + ); +}); + +Flow.displayName = 'Flow'; + +type FlowSurfaceProps = { + flowWrapper: { current: HTMLDivElement | null }; + viewport: ReactFlowProps['defaultViewport']; + renderedNodes: AnyNode[]; + renderedEdges: AnyEdge[]; + onInit: OnInit; + onMouseMove: (event: MouseEvent) => void; + onNodesChange: OnNodesChange; + onEdgesChange: OnEdgesChange; + onReconnect: OnReconnect; + onReconnectStart: NonNullable['onReconnectStart']>; + onReconnectEnd: NonNullable['onReconnectEnd']>; + onConnectStart: NonNullable['onConnectStart']>; + onConnect: NonNullable['onConnect']>; + onConnectEnd: NonNullable['onConnectEnd']>; + handleMoveEnd: OnMoveEnd; + onEdgeDoubleClick: NonNullable['onEdgeDoubleClick']>; + isValidConnection: NonNullable['isValidConnection']>; + shouldSnapToGrid: boolean; + flowStyles: CSSProperties; + handlePaneClick: NonNullable['onPaneClick']>; + selectionMode: ReturnType; +}; + +const FlowSurface = memo((props: FlowSurfaceProps) => { + const { + flowWrapper, + viewport, + renderedNodes, + renderedEdges, + onInit, + onMouseMove, + onNodesChange, + onEdgesChange, + onReconnect, + onReconnectStart, + onReconnectEnd, + onConnectStart, + onConnect, + onConnectEnd, + handleMoveEnd, + onEdgeDoubleClick, + isValidConnection, + shouldSnapToGrid, + flowStyles, + handlePaneClick, + selectionMode, + } = props; + + const setFlowWrapperElement = useCallback( + (el: HTMLDivElement | null) => { + flowWrapper.current = el; + }, + [flowWrapper] + ); + + return ( +
id="workflow-editor" - ref={flowWrapper} defaultViewport={viewport} nodeTypes={nodeTypes} edgeTypes={edgeTypes} - nodes={nodes} - edges={edges} + nodes={renderedNodes} + edges={renderedEdges} onInit={onInit} onMouseMove={onMouseMove} onNodesChange={onNodesChange} @@ -230,6 +555,7 @@ export const Flow = memo(() => { onConnect={onConnect} onConnectEnd={onConnectEnd} onMoveEnd={handleMoveEnd} + onEdgeDoubleClick={onEdgeDoubleClick} connectionLineComponent={CustomConnectionLine} isValidConnection={isValidConnection} minZoom={0.1} @@ -249,12 +575,11 @@ export const Flow = memo(() => { > - - +
); }); -Flow.displayName = 'Flow'; +FlowSurface.displayName = 'FlowSurface'; const HotkeyIsolator = memo(({ flowWrapper }: { flowWrapper: RefObject }) => { const mayUndo = useAppSelector(selectMayUndo); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts index b64e5a6e6a..40a44a16c2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/edges/util/buildEdgeSelectors.ts @@ -1,9 +1,10 @@ import { createSelector } from '@reduxjs/toolkit'; import { colorTokenToCssVar } from 'common/util/colorTokenToCssVar'; -import { selectNodes } from 'features/nodes/store/selectors'; +import { selectEdges, selectNodes } from 'features/nodes/store/selectors'; import type { Templates } from 'features/nodes/store/types'; +import { resolveConnectorSource, resolveConnectorSourceFieldType } from 'features/nodes/store/util/connectorTopology'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; -import { isInvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; import { getFieldColor } from './getEdgeColor'; @@ -22,23 +23,20 @@ export const buildSelectEdgeColor = ( target: string, targetHandleId: string | null | undefined ) => - createSelector(selectNodes, selectWorkflowSettingsSlice, (nodes, workflowSettings): string => { + createSelector(selectNodes, selectEdges, selectWorkflowSettingsSlice, (nodes, edges, workflowSettings): string => { const { shouldColorEdges } = workflowSettings; if (!shouldColorEdges) { return colorTokenToCssVar('base.500'); } const sourceNode = nodes.find((node) => node.id === source); - const targetNode = nodes.find((node) => node.id === target); - if (!sourceNode || !sourceHandleId || !targetNode || !targetHandleId) { + if (!sourceNode || !sourceHandleId || !targetHandleId) { return colorTokenToCssVar('base.500'); } - const sourceNodeTemplate = templates[sourceNode.data.type]; - - const isInvocationToInvocationEdge = isInvocationNode(sourceNode) && isInvocationNode(targetNode); - const outputFieldTemplate = sourceNodeTemplate?.outputs[sourceHandleId]; - const sourceType = isInvocationToInvocationEdge ? outputFieldTemplate?.type : undefined; + const sourceType = isConnectorNode(sourceNode) + ? resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates) + : templates[sourceNode.data.type]?.outputs[sourceHandleId]?.type; return sourceType ? getFieldColor(sourceType) : colorTokenToCssVar('base.500'); }); @@ -50,7 +48,7 @@ export const buildSelectEdgeLabel = ( target: string, targetHandleId: string | null | undefined ) => - createSelector(selectNodes, (nodes): string | null => { + createSelector(selectNodes, selectEdges, (nodes, edges): string | null => { const sourceNode = nodes.find((node) => node.id === source); const targetNode = nodes.find((node) => node.id === target); @@ -58,8 +56,12 @@ export const buildSelectEdgeLabel = ( return null; } - const sourceNodeTemplate = templates[sourceNode.data.type]; + const resolvedSource = isConnectorNode(sourceNode) ? resolveConnectorSource(sourceNode.id, nodes, edges) : null; + const sourceTemplate = + resolvedSource !== null + ? templates[nodes.find((node) => node.id === resolvedSource.nodeId)?.data.type ?? ''] + : templates[sourceNode.data.type]; const targetNodeTemplate = templates[targetNode.data.type]; - return `${sourceNodeTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; + return `${sourceTemplate?.title || sourceNode.data?.label} -> ${targetNodeTemplate?.title || targetNode.data?.label}`; }); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx new file mode 100644 index 0000000000..a971efb397 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Connector/ConnectorNode.tsx @@ -0,0 +1,97 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; +import { Box, Icon } from '@invoke-ai/ui-library'; +import type { Node, NodeProps } from '@xyflow/react'; +import { Handle, Position } from '@xyflow/react'; +import NonInvocationNodeWrapper from 'features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; +import { NO_DRAG_CLASS } from 'features/nodes/types/constants'; +import type { ConnectorNodeData } from 'features/nodes/types/invocation'; +import type { CSSProperties } from 'react'; +import { memo } from 'react'; +import { PiDotOutlineFill } from 'react-icons/pi'; + +const handleVisualSx = { + w: 3, + h: 3, + borderRadius: 'full', + borderWidth: 2, + borderColor: 'base.900', + bg: 'base.100', + pointerEvents: 'none', +} satisfies SystemStyleObject; + +const handleStyles = { + position: 'absolute', + width: '1rem', + height: '1rem', + display: 'flex', + alignItems: 'center', + justifyContent: 'center', + top: '50%', + transform: 'translateY(-50%)', + zIndex: 1, + background: 'none', + border: 'none', +} satisfies CSSProperties; + +const inputHandleStyles = { + ...handleStyles, + insetInlineStart: 0, + justifyContent: 'flex-start', +} satisfies CSSProperties; + +const outputHandleStyles = { + ...handleStyles, + insetInlineEnd: 0, + justifyContent: 'flex-end', +} satisfies CSSProperties; + +const ConnectorNode = ({ id, selected }: NodeProps>) => { + return ( + + + + + + + + + + + + + + ); +}; + +export default memo(ConnectorNode); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx index 7e2cde7093..84246ba43b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/NonInvocationNodeWrapper.tsx @@ -16,10 +16,12 @@ type NonInvocationNodeWrapperProps = PropsWithChildren & { nodeId: string; selected: boolean; width?: ChakraProps['w']; + borderRadius?: ChakraProps['borderRadius']; + withChrome?: boolean; }; const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => { - const { nodeId, width, children, selected } = props; + const { nodeId, width, children, selected, borderRadius = 'base', withChrome = true } = props; const mouseOverNode = useMouseOverNode(nodeId); const zoomToNode = useZoomToNode(nodeId); @@ -62,14 +64,15 @@ const NonInvocationNodeWrapper = (props: NonInvocationNodeWrapperProps) => { onMouseOut={mouseOverNode.handleMouseOut} className={DRAG_HANDLE_CLASSNAME} sx={containerSx} + borderRadius={borderRadius} width={width || NODE_WIDTH} opacity={opacity} data-is-selected={selected} > - - + {withChrome && } + {withChrome && } {children} - + {withChrome && } ); }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts index 70e56cb4db..15624780c2 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/common/shared.ts @@ -6,7 +6,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; export const containerSx: SystemStyleObject = { h: 'full', position: 'relative', - borderRadius: 'base', + borderRadius: 'inherit', transitionProperty: 'none', cursor: 'grab', '--border-color': 'var(--invoke-colors-base-500)', @@ -30,7 +30,7 @@ export const containerSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'base', + borderRadius: 'inherit', transitionProperty: 'none', pointerEvents: 'none', shadow: '0 0 0 1px var(--border-color)', @@ -64,7 +64,7 @@ export const shadowsSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'base', + borderRadius: 'inherit', pointerEvents: 'none', zIndex: -1, shadow: 'var(--invoke-shadows-xl), var(--invoke-shadows-base), var(--invoke-shadows-base)', @@ -76,7 +76,7 @@ export const inProgressSx: SystemStyleObject = { insetInlineEnd: 0, bottom: 0, insetInlineStart: 0, - borderRadius: 'md', + borderRadius: 'inherit', pointerEvents: 'none', transitionProperty: 'none', opacity: 0.7, diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index 591aac6c18..428fcf1e82 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -3,6 +3,7 @@ import { $templates } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; import { NODE_WIDTH } from 'features/nodes/types/constants'; import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode'; import { buildInvocationNode } from 'features/nodes/util/node/buildInvocationNode'; import { buildNotesNode } from 'features/nodes/util/node/buildNotesNode'; @@ -14,7 +15,7 @@ export const useBuildNode = () => { return useCallback( // string here is "any invocation type" - (type: string | 'current_image' | 'notes'): AnyNode => { + (type: string | 'connector' | 'current_image' | 'notes'): AnyNode => { const flow = $flow.get(); assert(flow !== null); @@ -42,6 +43,10 @@ export const useBuildNode = () => { return buildNotesNode(position); } + if (type === 'connector') { + return buildConnectorNode(position); + } + // TODO: Keep track of invocation types so we do not need to cast this // We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates. const template = templates[type] as InvocationTemplate; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts index bfd8be95f5..d763254bc4 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnection.ts @@ -12,9 +12,15 @@ import { edgesChanged, } from 'features/nodes/store/nodesSlice'; import { selectNodes, selectNodesSlice } from 'features/nodes/store/selectors'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSourceFieldType, +} from 'features/nodes/store/util/connectorTopology'; import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import type { AnyEdge } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; import { useCallback, useMemo } from 'react'; import { assert } from 'tsafe'; @@ -33,6 +39,47 @@ export const useConnection = () => { return; } + if (isConnectorNode(node)) { + if (handleType === 'source' && handleId !== CONNECTOR_OUTPUT_HANDLE) { + return; + } + if (handleType === 'target' && handleId !== CONNECTOR_INPUT_HANDLE) { + return; + } + + const resolvedSourceType = + handleType === 'source' + ? resolveConnectorSourceFieldType(nodeId, nodes, selectNodesSlice(store.getState()).edges, templates) + : null; + $pendingConnection.set({ + nodeId, + handleId, + handleType, + fieldTemplate: + handleType === 'source' + ? { + name: CONNECTOR_OUTPUT_HANDLE, + title: 'Connector Output', + description: '', + fieldKind: 'output', + ui_hidden: false, + type: resolvedSourceType ?? { name: 'AnyField', cardinality: 'SINGLE', batch: false }, + } + : { + name: CONNECTOR_INPUT_HANDLE, + title: 'Connector Input', + description: '', + fieldKind: 'input', + input: 'connection', + required: false, + default: undefined, + ui_hidden: false, + type: { name: 'AnyField', cardinality: 'SINGLE', batch: false }, + }, + }); + return; + } + const template = templates[node.data.type]; if (!template) { return; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts new file mode 100644 index 0000000000..5306347921 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.test.ts @@ -0,0 +1,160 @@ +import { deepClone } from 'common/util/deepClone'; +import { buildConnectorNode } from 'features/nodes/util/node/buildConnectorNode'; +import { describe, expect, it } from 'vitest'; + +import { connectorInserted, nodesChanged, nodesSliceConfig } from './nodesSlice'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from './util/connectorTopology'; +import { add, buildEdge, buildNode, sub } from './util/testUtils'; + +const buildFixedConnectorNode = (id: string) => { + const connectorNode = buildConnectorNode({ x: 0, y: 0 }); + return { + ...connectorNode, + id, + data: { + ...connectorNode.data, + id, + }, + }; +}; + +describe('nodesSlice connector actions', () => { + it('splits a direct edge into source -> connector -> target edges when inserting a connector', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + const directEdge = buildEdge(source.id, 'value', target.id, 'a'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, target]; + initialState.edges = [directEdge]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + connectorInserted({ + edgeId: directEdge.id, + connector, + }) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id, connector.id]); + expect(nextState.edges).toEqual([ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]); + }); + + it('splices connector outputs back to the resolved upstream source when removed', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]); + expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]); + }); + + it('splices one connector source back to multiple downstream targets when removed', () => { + const source = buildNode(add); + const targetA = buildNode(sub); + const targetB = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, targetA, targetB]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'b'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, targetA.id, targetB.id]); + expect(nextState.edges).toEqual([ + buildEdge(source.id, 'value', targetA.id, 'a'), + buildEdge(source.id, 'value', targetB.id, 'b'), + ]); + }); + + it('does not create any edges when removing a connector with no downstream targets', () => { + const source = buildNode(add); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector]; + initialState.edges = [buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id]); + expect(nextState.edges).toEqual([]); + }); + + it('removes a connector while preserving downstream connector edges in a chained splice case', () => { + const source = buildNode(add); + const connectorA = buildFixedConnectorNode('connector-a'); + const connectorB = buildFixedConnectorNode('connector-b'); + const target = buildNode(sub); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connectorA, connectorB, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connectorA.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, connectorB.id, target.id]); + expect(nextState.edges).toHaveLength(2); + expect(nextState.edges).toEqual( + expect.arrayContaining([ + buildEdge(source.id, 'value', connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]) + ); + }); + + it('splices connector edges when the connector is removed through generic node removal', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildFixedConnectorNode('connector-1'); + + const initialState = deepClone(nodesSliceConfig.slice.reducer(undefined, { type: 'test/init' })); + initialState.nodes = [source, connector, target]; + initialState.edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + const nextState = nodesSliceConfig.slice.reducer( + initialState, + nodesChanged([{ type: 'remove', id: connector.id }]) + ); + + expect(nextState.nodes.map((node) => node.id)).toEqual([source.id, target.id]); + expect(nextState.edges).toEqual([buildEdge(source.id, 'value', target.id, 'a')]); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index bdab6c1ae3..6713ee8fb4 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -20,6 +20,13 @@ import { reparentElement, } from 'features/nodes/components/sidePanel/builder/form-manipulation'; import { type NodesState, zNodesState } from 'features/nodes/store/types'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorOutputEdges, + resolveConnectorSource, +} from 'features/nodes/store/util/connectorTopology'; +import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { BoardFieldValue, @@ -65,8 +72,8 @@ import { zStringGeneratorFieldValue, zStylePresetFieldValue, } from 'features/nodes/types/field'; -import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; -import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; +import type { AnyEdge, AnyNode, ConnectorNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import type { BuilderForm, ContainerElement, @@ -103,7 +110,7 @@ export const getInitialWorkflow = (): Omit[]>) => { + const removedConnectorSpliceEdges: AnyEdge[] = action.payload.flatMap((change) => { + if (change.type !== 'remove') { + return []; + } + + const node = state.nodes.find((candidate) => candidate.id === change.id); + if (!isConnectorNode(node)) { + return []; + } + + const resolvedSource = resolveConnectorSource(node.id, state.nodes, state.edges); + if (!resolvedSource) { + return []; + } + + return getConnectorOutputEdges(node.id, state.edges) + .filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default') + .map((edge) => + connectionToEdge({ + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + target: edge.target, + targetHandle: edge.targetHandle, + }) + ); + }); + // TODO(psyche): The below TS issue was recently fixed upstream. Need to upgrade @xyflow/react and then we // should be able to remove this cast. // @@ -206,6 +240,12 @@ const slice = createSlice({ if (edgeChanges.length > 0) { state.edges = applyEdgeChanges(edgeChanges, state.edges); } + if (removedConnectorSpliceEdges.length > 0) { + state.edges = applyEdgeChanges( + removedConnectorSpliceEdges.map((edge) => ({ type: 'add', item: edge })), + state.edges + ); + } } const wereNodesRemoved = action.payload.some((change) => change.type === 'remove' || change.type === 'replace'); @@ -396,11 +436,40 @@ const slice = createSlice({ } } }, + connectorInserted: ( + state, + action: PayloadAction<{ + edgeId: string; + connector: ConnectorNode; + }> + ) => { + const { edgeId, connector } = action.payload; + const edge = state.edges.find((candidate) => candidate.id === edgeId); + if (!edge || edge.type !== 'default') { + return; + } + state.nodes.push({ ...SHARED_NODE_PROPERTIES, ...connector } as (typeof state.nodes)[number]); + state.edges = state.edges.filter((candidate) => candidate.id !== edgeId); + state.edges.push( + connectionToEdge({ + source: edge.source, + sourceHandle: edge.sourceHandle ?? null, + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }), + connectionToEdge({ + source: connector.id, + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: edge.target, + targetHandle: edge.targetHandle ?? null, + }) + ); + }, nodeLabelChanged: (state, action: PayloadAction<{ nodeId: string; label: string }>) => { const { nodeId, label } = action.payload; const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId); const node = state.nodes?.[nodeIndex]; - if (isInvocationNode(node) || isNotesNode(node)) { + if (isInvocationNode(node) || isNotesNode(node) || isConnectorNode(node)) { node.data.label = label; } }, @@ -614,6 +683,7 @@ export const { nodeEditorReset, nodeIsIntermediateChanged, nodeIsOpenChanged, + connectorInserted, nodeLabelChanged, nodeNotesChanged, nodesChanged, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts new file mode 100644 index 0000000000..e87ebcde79 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.test.ts @@ -0,0 +1,126 @@ +import type { AnyNode, ConnectorNode } from 'features/nodes/types/invocation'; +import { describe, expect, it } from 'vitest'; + +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorDeletionSpliceConnections, + getConnectorInputEdge, + getConnectorOutputEdges, + resolveConnectorSource, + resolveConnectorSourceFieldType, +} from './connectorTopology'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from './testUtils'; + +const buildConnectorNode = (id: string): ConnectorNode => ({ + id, + type: 'connector', + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector', + label: 'Connector', + isOpen: true, + }, +}); + +describe('connectorTopology', () => { + it('resolves the effective upstream source through one connector', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + expect(resolveConnectorSource(connector.id, nodes, edges)).toEqual({ + nodeId: source.id, + fieldName: 'value', + }); + expect(resolveConnectorSourceFieldType(connector.id, nodes, edges, templates)).toEqual(add.outputs.value?.type); + }); + + it('resolves the effective upstream source through chained connectors', () => { + const source = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const nodes: AnyNode[] = [source, connectorA, connectorB]; + const edges = [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + + expect(resolveConnectorSource(connectorB.id, nodes, edges)).toEqual({ + nodeId: source.id, + fieldName: 'value', + }); + }); + + it('returns no source or type for an unresolved connector chain', () => { + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const nodes: AnyNode[] = [connectorA, connectorB]; + const edges = [buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE)]; + + expect(resolveConnectorSource(connectorB.id, nodes, edges)).toBe(null); + expect(resolveConnectorSourceFieldType(connectorB.id, nodes, edges, templates)).toBe(null); + }); + + it('enumerates multiple outgoing edges for a connector', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const targetA = buildNode(sub); + const targetB = buildNode(img_resize); + const incoming = buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE); + const outgoingA = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'); + const outgoingB = buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width'); + const edges = [incoming, outgoingA, outgoingB]; + + expect(getConnectorInputEdge(connector.id, edges)).toEqual(incoming); + expect(getConnectorOutputEdges(connector.id, edges)).toEqual([outgoingA, outgoingB]); + }); + + it('rejects connector deletion splice-through when any downstream target would be invalid', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(img_resize); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'image'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null); + }); + + it('builds connector deletion splice-through edges when every downstream target remains valid', () => { + const source = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [source, connector, target]; + const edges = [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toEqual([ + { + source: source.id, + sourceHandle: 'value', + target: target.id, + targetHandle: 'a', + }, + ]); + }); + + it('returns no splice-through edges when a connector has downstream targets but no upstream source', () => { + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const nodes: AnyNode[] = [connector, target]; + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]; + + expect(getConnectorDeletionSpliceConnections(connector.id, nodes, edges, templates)).toBe(null); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts new file mode 100644 index 0000000000..e1267763d7 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/connectorTopology.ts @@ -0,0 +1,228 @@ +import type { Templates } from 'features/nodes/store/types'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; + +export const CONNECTOR_INPUT_HANDLE = 'in'; +export const CONNECTOR_OUTPUT_HANDLE = 'out'; + +type ResolvedConnectorSource = { + nodeId: string; + fieldName: string; +}; + +type SpliceConnection = { + source: string; + sourceHandle: string; + target: string; + targetHandle: string; +}; + +type SpliceConnectionValidator = ( + connection: SpliceConnection, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates, + ignoreEdge: AnyEdge | null, + strict?: boolean +) => string | null; + +export const getConnectorInputEdge = (connectorId: string, edges: AnyEdge[]): AnyEdge | undefined => + edges.find( + (edge) => + edge.type === 'default' && + edge.target === connectorId && + edge.targetHandle === CONNECTOR_INPUT_HANDLE && + typeof edge.sourceHandle === 'string' + ); + +export const getConnectorOutputEdges = (connectorId: string, edges: AnyEdge[]): AnyEdge[] => + edges.filter( + (edge) => + edge.type === 'default' && + edge.source === connectorId && + edge.sourceHandle === CONNECTOR_OUTPUT_HANDLE && + typeof edge.targetHandle === 'string' + ); + +export const resolveConnectorSource = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[] +): ResolvedConnectorSource | null => { + const visited = new Set(); + + const resolve = (nodeId: string): ResolvedConnectorSource | null => { + if (visited.has(nodeId)) { + return null; + } + visited.add(nodeId); + + const incomingEdge = getConnectorInputEdge(nodeId, edges); + if (!incomingEdge || incomingEdge.type !== 'default') { + return null; + } + if (typeof incomingEdge.sourceHandle !== 'string') { + return null; + } + + const sourceNode = nodes.find((node) => node.id === incomingEdge.source); + if (!sourceNode) { + return null; + } + + if (isInvocationNode(sourceNode)) { + return { nodeId: sourceNode.id, fieldName: incomingEdge.sourceHandle }; + } + + if (isConnectorNode(sourceNode)) { + return resolve(sourceNode.id); + } + + return null; + }; + + return resolve(connectorId); +}; + +export const resolveConnectorSourceFieldType = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): FieldType | null => { + const resolvedSource = resolveConnectorSource(connectorId, nodes, edges); + if (!resolvedSource) { + return null; + } + + const sourceNode = nodes.find((node) => node.id === resolvedSource.nodeId); + if (!sourceNode || !isInvocationNode(sourceNode)) { + return null; + } + + const sourceTemplate = templates[sourceNode.data.type]; + return sourceTemplate?.outputs[resolvedSource.fieldName]?.type ?? null; +}; + +export const getConnectorDeletionSpliceConnections = ( + connectorId: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates, + validateConnection?: SpliceConnectionValidator +): SpliceConnection[] | null => { + const resolvedSource = resolveConnectorSource(connectorId, nodes, edges); + if (!resolvedSource) { + return null; + } + + const outputEdges = getConnectorOutputEdges(connectorId, edges); + const spliceConnections = outputEdges + .filter((edge): edge is AnyEdge & { type: 'default'; targetHandle: string } => edge.type === 'default') + .map((edge) => ({ + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + target: edge.target, + targetHandle: edge.targetHandle, + })); + + const deduped = new Set(); + for (const connection of spliceConnections) { + const key = `${connection.source}:${connection.sourceHandle}->${connection.target}:${connection.targetHandle}`; + if (deduped.has(key)) { + return null; + } + deduped.add(key); + } + + if (!validateConnection) { + const sourceType = resolveConnectorSourceFieldType(connectorId, nodes, edges, templates); + if (!sourceType) { + return null; + } + const inputEdgeId = getConnectorInputEdge(connectorId, edges)?.id; + const outputEdgeIds = new Set(outputEdges.map((edge) => edge.id)); + + for (const connection of spliceConnections) { + const targetNode = nodes.find((node) => node.id === connection.target); + if (!targetNode || !isInvocationNode(targetNode)) { + return null; + } + const targetTemplate = templates[targetNode.data.type]; + const targetFieldTemplate = targetTemplate?.inputs[connection.targetHandle]; + if (!targetFieldTemplate) { + return null; + } + + const matchesExistingDirectEdge = edges.some( + (edge) => + edge.type === 'default' && + edge.source === connection.source && + edge.sourceHandle === connection.sourceHandle && + edge.target === connection.target && + edge.targetHandle === connection.targetHandle + ); + if (matchesExistingDirectEdge) { + return null; + } + + const targetConflictCount = spliceConnections.filter( + (candidate) => candidate.target === connection.target && candidate.targetHandle === connection.targetHandle + ).length; + const existingTargetConflict = edges.some( + (edge) => + edge.type === 'default' && + edge.id !== inputEdgeId && + !outputEdgeIds.has(edge.id) && + edge.target === connection.target && + edge.targetHandle === connection.targetHandle + ); + if ( + targetFieldTemplate.type.name !== 'CollectionItemField' && + (targetConflictCount > 1 || existingTargetConflict) + ) { + return null; + } + + if ( + sourceType.name !== targetFieldTemplate.type.name && + targetFieldTemplate.type.name !== 'CollectionItemField' + ) { + return null; + } + } + + return spliceConnections; + } + + const ignoredEdgeIds = new Set([ + getConnectorInputEdge(connectorId, edges)?.id, + ...outputEdges.map((edge) => edge.id), + ]); + const existingEdges = edges.filter((edge) => !ignoredEdgeIds.has(edge.id)); + const stagedConnections: SpliceConnection[] = []; + + for (const connection of spliceConnections) { + const stagedEdges = [ + ...existingEdges, + ...stagedConnections.map( + ({ source, sourceHandle, target, targetHandle }) => + ({ + id: `splice-${source}-${sourceHandle}-${target}-${targetHandle}`, + type: 'default', + source, + sourceHandle, + target, + targetHandle, + }) satisfies AnyEdge + ), + ]; + if (validateConnection(connection, nodes, stagedEdges, templates, null, true) !== null) { + return null; + } + stagedConnections.push(connection); + } + + return spliceConnections; +}; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts index 288a4f9066..b4374a920c 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.test.ts @@ -1,13 +1,26 @@ import { deepClone } from 'common/util/deepClone'; import { unset } from 'es-toolkit/compat'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; import { getFirstValidConnection, getSourceCandidateFields, getTargetCandidateFields, } from 'features/nodes/store/util/getFirstValidConnection'; -import { add, buildEdge, buildNode, img_resize, templates } from 'features/nodes/store/util/testUtils'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils'; import { describe, expect, it } from 'vitest'; +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + describe('getFirstValidConnection', () => { it('should return null if the pending and candidate nodes are the same node', () => { const n = buildNode(add); @@ -120,6 +133,33 @@ describe('getFirstValidConnection', () => { expect(r).toEqual(null); }); }); + + it('should resolve connector target candidates when connecting an invocation output to a connector', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + expect(getFirstValidConnection(n1.id, 'value', connector.id, null, [n1, connector], [], templates, null)).toEqual({ + source: n1.id, + sourceHandle: 'value', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + }); + + it('should resolve connector source candidates when connecting a connector to a typed invocation input', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + + expect( + getFirstValidConnection(connector.id, null, n2.id, 'width', [n1, connector, n2], edges, templates, null) + ).toEqual({ + source: connector.id, + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: n2.id, + targetHandle: 'width', + }); + }); }); describe('getTargetCandidateFields', () => { @@ -160,6 +200,62 @@ describe('getTargetCandidateFields', () => { const r = getTargetCandidateFields(n1.id, 'width', n2.id, nodes, [], templates, edgePendingUpdate); expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); }); + it('should return the connector input handle when the target is a connector', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const r = getTargetCandidateFields(n1.id, 'value', connector.id, [n1, connector], [], templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_INPUT_HANDLE]); + }); + it('should advertise typed target candidates for an unresolved connector output when no downstream constraint exists', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const r = getTargetCandidateFields( + connector.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [connector, n2], + [], + templates, + null + ); + expect(r.map((field) => field.name)).toEqual(['a', 'b']); + }); + it('should only advertise compatible typed target candidates for an unresolved connector output with downstream constraints', () => { + const connector = buildConnectorNode('connector-1'); + const n1 = buildNode(sub); + const n2 = buildNode(img_resize); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')]; + const r = getTargetCandidateFields( + connector.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [connector, n1, n2], + edges, + templates, + null + ); + expect(r.map((field) => field.name)).toEqual(['width', 'height']); + }); + it('should resolve chained connector sources like the direct upstream source', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(img_resize); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = getTargetCandidateFields( + connectorB.id, + CONNECTOR_OUTPUT_HANDLE, + n2.id, + [n1, connectorA, connectorB, n2], + edges, + templates, + null + ); + expect(r).toEqual([img_resize.inputs['width'], img_resize.inputs['height']]); + }); }); describe('getSourceCandidateFields', () => { @@ -200,4 +296,18 @@ describe('getSourceCandidateFields', () => { const r = getSourceCandidateFields(n2.id, 'width', n1.id, nodes, [], templates, edgePendingUpdate); expect(r).toEqual([img_resize.outputs['width'], img_resize.outputs['height']]); }); + it('should return the connector output handle when the source is a connector with a typed upstream source', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = getSourceCandidateFields(n2.id, 'width', connector.id, [n1, connector, n2], edges, templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]); + }); + it('should return a target-constrained connector source candidate when the connector chain is unresolved', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const r = getSourceCandidateFields(n2.id, 'width', connector.id, [connector, n2], [], templates, null); + expect(r.map((field) => field.name)).toEqual([CONNECTOR_OUTPUT_HANDLE]); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts index 0b5aa17e17..2a466aae44 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/getFirstValidConnection.ts @@ -1,9 +1,15 @@ import type { Connection } from '@xyflow/react'; import { map } from 'es-toolkit/compat'; import type { Templates } from 'features/nodes/store/types'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSourceFieldType, +} from 'features/nodes/store/util/connectorTopology'; import { validateConnection } from 'features/nodes/store/util/validateConnection'; import type { FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import { isConnectorNode } from 'features/nodes/types/invocation'; /** * @@ -91,16 +97,43 @@ export const getTargetCandidateFields = ( return []; } - const sourceTemplate = templates[sourceNode.data.type]; + if (isConnectorNode(targetNode)) { + const candidate = { + name: CONNECTOR_INPUT_HANDLE, + title: 'Connector Input', + description: '', + fieldKind: 'input', + input: 'connection', + required: false, + default: undefined, + ui_hidden: false, + type: { + name: 'AnyField', + cardinality: 'SINGLE', + batch: false, + }, + } satisfies FieldInputTemplate; + + const c = { source, sourceHandle, target, targetHandle: candidate.name }; + return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : []; + } + const targetTemplate = templates[targetNode.data.type]; - if (!sourceTemplate || !targetTemplate) { + if (!targetTemplate) { return []; } - const sourceField = sourceTemplate.outputs[sourceHandle]; + if (!isConnectorNode(sourceNode)) { + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return []; + } - if (!sourceField) { - return []; + const sourceField = sourceTemplate.outputs[sourceHandle]; + + if (!sourceField) { + return []; + } } const targetCandidateFields = map(targetTemplate.inputs).filter((field) => { @@ -127,15 +160,45 @@ export const getSourceCandidateFields = ( return []; } + if (isConnectorNode(sourceNode)) { + const sourceFieldType = resolveConnectorSourceFieldType(sourceNode.id, nodes, edges, templates); + const targetTemplate = !isConnectorNode(targetNode) ? templates[targetNode.data.type] : null; + const targetFieldType = targetTemplate?.inputs[targetHandle]?.type; + const candidateType = sourceFieldType ?? targetFieldType; + if (!candidateType) { + return []; + } + + const candidate = { + name: CONNECTOR_OUTPUT_HANDLE, + title: 'Connector Output', + description: '', + fieldKind: 'output', + ui_hidden: false, + type: candidateType, + } satisfies FieldOutputTemplate; + + const c = { source, sourceHandle: candidate.name, target, targetHandle }; + return validateConnection(c, nodes, edges, templates, edgePendingUpdate, true) === null ? [candidate] : []; + } + const sourceTemplate = templates[sourceNode.data.type]; - const targetTemplate = templates[targetNode.data.type]; - if (!sourceTemplate || !targetTemplate) { + if (!sourceTemplate) { return []; } - const targetField = targetTemplate.inputs[targetHandle]; + if (!isConnectorNode(targetNode)) { + const targetTemplate = templates[targetNode.data.type]; + if (!targetTemplate) { + return []; + } - if (!targetField) { + const targetField = targetTemplate.inputs[targetHandle]; + + if (!targetField) { + return []; + } + } else if (targetHandle !== CONNECTOR_INPUT_HANDLE) { return []; } diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts new file mode 100644 index 0000000000..b70eda4bda --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.test.ts @@ -0,0 +1,23 @@ +import { describe, expect, it } from 'vitest'; + +import { connectionToEdge } from './reactFlowUtil'; + +describe('connectionToEdge', () => { + it('creates a default edge with the expected id and endpoints', () => { + expect( + connectionToEdge({ + source: 'source-node', + sourceHandle: 'value', + target: 'target-node', + targetHandle: 'a', + }) + ).toEqual({ + type: 'default', + source: 'source-node', + sourceHandle: 'value', + target: 'target-node', + targetHandle: 'a', + id: 'reactflow__edge-source-nodevalue-target-nodea', + }); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts index 67264bf41b..3eaece154f 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/reactFlowUtil.ts @@ -24,6 +24,7 @@ export const connectionToEdge = (connection: Connection): AnyEdge => { const { source, sourceHandle, target, targetHandle } = connection; assert(source && sourceHandle && target && targetHandle, 'Invalid connection'); return { + type: 'default', source, sourceHandle, target, diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 88eae8484f..29395a1e1d 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -3,6 +3,11 @@ import { set } from 'es-toolkit/compat'; import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { describe, expect, it } from 'vitest'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + getConnectorDeletionSpliceConnections, +} from './connectorTopology'; import { add, buildEdge, buildNode, collect, img_resize, main_model_loader, sub, templates } from './testUtils'; import { validateConnection } from './validateConnection'; @@ -139,6 +144,18 @@ const integerCollectionOutputTemplate: InvocationTemplate = { classification: 'stable', }; +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + describe(validateConnection.name, () => { it('should reject invalid connection to self', () => { const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; @@ -458,6 +475,214 @@ describe(validateConnection.name, () => { expect(r).toEqual('nodes.connectionWouldCreateCycle'); }); + describe('connectors', () => { + it('should accept invocation output to connector input', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const r = validateConnection( + { source: n1.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [n1, connector], + [], + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject a second input into a connector', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'value', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [n1, n2, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.inputMayOnlyHaveOneConnection'); + }); + + it('should accept connector output to invocation input when the upstream type matches', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [n1, connector, n2], + edges, + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject connector output to invocation input when the upstream type mismatches', () => { + const n1 = buildNode(add); + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(img_resize); + const edges = [buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE)]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' }, + [n1, connector, n2], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should accept unresolved connector output to a typed invocation input as the first downstream constraint', () => { + const connector = buildConnectorNode('connector-1'); + const n2 = buildNode(sub); + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [connector, n2], + [], + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject unresolved connector output when it conflicts with an existing downstream typed constraint', () => { + const connector = buildConnectorNode('connector-1'); + const n1 = buildNode(sub); + const n2 = buildNode(img_resize); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n1.id, 'a')]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'image' }, + [connector, n1, n2], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should reject connecting an incompatible upstream source into a connector with downstream typed constraints', () => { + const source = buildNode(main_model_loader); + const connector = buildConnectorNode('connector-1'); + const target = buildNode(sub); + const edges = [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]; + const r = validateConnection( + { source: source.id, sourceHandle: 'vae', target: connector.id, targetHandle: CONNECTOR_INPUT_HANDLE }, + [source, connector, target], + edges, + templates, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should preserve type information through chained connectors', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(sub); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = validateConnection( + { source: connectorB.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'a' }, + [n1, connectorA, connectorB, n2], + edges, + templates, + null + ); + expect(r).toEqual(null); + }); + + it('should reject cycles routed through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + ]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'value', target: n1.id, targetHandle: 'a' }, + [n1, n2, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.connectionWouldCreateCycle'); + }); + + it('should preserve collect item validation through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(collect); + const n3 = buildNode(main_model_loader); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', n2.id, 'item'), + buildEdge(n3.id, 'vae', connector.id, CONNECTOR_INPUT_HANDLE), + ]; + const r = validateConnection( + { source: connector.id, sourceHandle: CONNECTOR_OUTPUT_HANDLE, target: n2.id, targetHandle: 'item' }, + [n1, n2, n3, connector], + edges, + templates, + null + ); + expect(r).toEqual('nodes.cannotMixAndMatchCollectionItemTypes'); + }); + + it('should preserve if branch validation through connectors', () => { + const n1 = buildNode(add); + const n2 = buildNode(img_resize); + const n3 = buildNode(ifTemplate); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n3.id, 'true_input'), + ]; + const r = validateConnection( + { source: n2.id, sourceHandle: 'image', target: n3.id, targetHandle: 'false_input' }, + [n1, n2, n3, connector], + edges, + { ...templates, if: ifTemplate }, + null + ); + expect(r).toEqual('nodes.fieldTypesMustMatch'); + }); + + it('should reject connector deletion splice-through when it would duplicate an existing direct edge', () => { + const n1 = buildNode(add); + const n2 = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const edges = [ + buildEdge(n1.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + buildEdge(n1.id, 'value', n2.id, 'a'), + ]; + + expect(getConnectorDeletionSpliceConnections(connector.id, [n1, n2, connector], edges, templates)).toBe(null); + }); + + it('should reject connector deletion splice-through when fan-out would violate a single-input target', () => { + const n1 = buildNode(add); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const n2 = buildNode(sub); + const edges = [ + buildEdge(n1.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, n2.id, 'a'), + ]; + + expect( + getConnectorDeletionSpliceConnections(connectorA.id, [n1, connectorA, connectorB, n2], edges, templates) + ).toBe(null); + }); + }); + describe('non-strict mode', () => { it('should reject connections from self to self in non-strict mode', () => { const c = { source: 'add', sourceHandle: 'value', target: 'add', targetHandle: 'a' }; diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts index b342df064b..bb98d472d3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.ts @@ -1,10 +1,17 @@ import type { Connection as NullableConnection } from '@xyflow/react'; import type { Templates } from 'features/nodes/store/types'; import { areTypesEqual } from 'features/nodes/store/util/areTypesEqual'; +import { + CONNECTOR_INPUT_HANDLE, + CONNECTOR_OUTPUT_HANDLE, + resolveConnectorSource, +} from 'features/nodes/store/util/connectorTopology'; import { getCollectItemType } from 'features/nodes/store/util/getCollectItemType'; import { getHasCycles } from 'features/nodes/store/util/getHasCycles'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; -import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; +import type { FieldType } from 'features/nodes/types/field'; +import type { AnyEdge, AnyNode, InvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode } from 'features/nodes/types/invocation'; import type { SetNonNullable } from 'type-fest'; type Connection = SetNonNullable; @@ -18,6 +25,12 @@ type ValidateConnectionFunc = ( strict?: boolean ) => string | null; +type EffectiveSource = { + node: InvocationNode; + handle: string; + fieldTemplate: NonNullable['outputs'][string]; +}; + const getEqualityPredicate = (c: Connection) => (e: AnyEdge): boolean => { @@ -41,10 +54,7 @@ const isIfInputHandle = (handle: string): handle is (typeof IF_INPUT_HANDLES)[nu return IF_INPUT_HANDLES.includes(handle as (typeof IF_INPUT_HANDLES)[number]); }; -const isSingleCollectionPairOfSameBaseType = ( - firstType: { name: string; cardinality: string; batch: boolean }, - secondType: { name: string; cardinality: string; batch: boolean } -) => { +const isSingleCollectionPairOfSameBaseType = (firstType: FieldType, secondType: FieldType) => { const isSingleToCollection = firstType.cardinality === 'SINGLE' && secondType.cardinality === 'COLLECTION' && firstType.name === secondType.name; const isCollectionToSingle = @@ -52,6 +62,139 @@ const isSingleCollectionPairOfSameBaseType = ( return firstType.batch === secondType.batch && (isSingleToCollection || isCollectionToSingle); }; +const areFieldTypesCompatible = (firstType: FieldType, secondType: FieldType) => + validateConnectionTypes(firstType, secondType) || + validateConnectionTypes(secondType, firstType) || + isSingleCollectionPairOfSameBaseType(firstType, secondType); + +type ConnectorTerminalTargetEdge = AnyEdge & { + type: 'default'; + sourceHandle: string; + targetHandle: string; +}; + +const getConnectorTerminalTargetEdges = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => { + const visited = new Set(); + const resolve = (currentConnectorId: string): ConnectorTerminalTargetEdge[] => { + if (visited.has(currentConnectorId)) { + return []; + } + visited.add(currentConnectorId); + + return edges.flatMap((edge) => { + if ( + edge.type !== 'default' || + edge.source !== currentConnectorId || + edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE || + typeof edge.targetHandle !== 'string' + ) { + return []; + } + + const targetNode = nodes.find((node) => node.id === edge.target); + if (targetNode && isConnectorNode(targetNode)) { + return resolve(targetNode.id); + } + + return [edge as ConnectorTerminalTargetEdge]; + }); + }; + + return resolve(connectorId); +}; + +const getConnectorSubgraphEdgeIds = (connectorId: string, nodes: AnyNode[], edges: AnyEdge[]) => { + const visited = new Set(); + const edgeIds = new Set(); + + const visit = (currentConnectorId: string) => { + if (visited.has(currentConnectorId)) { + return; + } + visited.add(currentConnectorId); + + edges.forEach((edge) => { + if ( + edge.type !== 'default' || + edge.source !== currentConnectorId || + edge.sourceHandle !== CONNECTOR_OUTPUT_HANDLE || + typeof edge.targetHandle !== 'string' + ) { + return; + } + + edgeIds.add(edge.id); + + const targetNode = nodes.find((node) => node.id === edge.target); + if (targetNode && isConnectorNode(targetNode)) { + visit(targetNode.id); + } + }); + }; + + visit(connectorId); + return edgeIds; +}; + +const getEffectiveSource = ( + sourceId: string, + sourceHandle: string, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => { + const sourceNode = nodes.find((n) => n.id === sourceId); + if (!sourceNode) { + return 'nodes.missingNode'; + } + + if (isConnectorNode(sourceNode)) { + if (sourceHandle !== CONNECTOR_OUTPUT_HANDLE) { + return 'nodes.missingFieldTemplate'; + } + + const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges); + if (!resolvedSource) { + return null; + } + + return getEffectiveSource(resolvedSource.nodeId, resolvedSource.fieldName, nodes, edges, templates); + } + + if (!isInvocationNode(sourceNode)) { + return 'nodes.missingInvocationTemplate'; + } + + const sourceTemplate = templates[sourceNode.data.type]; + if (!sourceTemplate) { + return 'nodes.missingInvocationTemplate'; + } + + const sourceFieldTemplate = sourceTemplate.outputs[sourceHandle]; + if (!sourceFieldTemplate) { + return 'nodes.missingFieldTemplate'; + } + + return { + node: sourceNode, + handle: sourceHandle, + fieldTemplate: sourceFieldTemplate, + }; +}; + +const getEffectiveSourceForEdge = ( + edge: AnyEdge, + nodes: AnyNode[], + edges: AnyEdge[], + templates: Templates +): EffectiveSource | 'nodes.missingNode' | 'nodes.missingInvocationTemplate' | 'nodes.missingFieldTemplate' | null => { + if (edge.type !== 'default' || typeof edge.sourceHandle !== 'string') { + return null; + } + + return getEffectiveSource(edge.source, edge.sourceHandle, nodes, edges, templates); +}; + /** * Validates a connection between two fields * @returns A translation key for an error if the connection is invalid, otherwise null @@ -83,18 +226,63 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.cannotDuplicateConnection'; } - const sourceNode = nodes.find((n) => n.id === c.source); - if (!sourceNode) { - return 'nodes.missingNode'; - } - const targetNode = nodes.find((n) => n.id === c.target); + const sourceNode = nodes.find((n) => n.id === c.source); if (!targetNode) { return 'nodes.missingNode'; } - const sourceTemplate = templates[sourceNode.data.type]; - if (!sourceTemplate) { + const effectiveSource = getEffectiveSource(c.source, c.sourceHandle, nodes, filteredEdges, templates); + if (effectiveSource === 'nodes.missingNode') { + return 'nodes.missingNode'; + } + if (effectiveSource === 'nodes.missingInvocationTemplate') { + return 'nodes.missingInvocationTemplate'; + } + if (effectiveSource === 'nodes.missingFieldTemplate') { + return 'nodes.missingFieldTemplate'; + } + + if (isConnectorNode(targetNode)) { + if (c.targetHandle !== CONNECTOR_INPUT_HANDLE) { + return 'nodes.missingFieldTemplate'; + } + + if (filteredEdges.find(getTargetEqualityPredicate(c))) { + return 'nodes.inputMayOnlyHaveOneConnection'; + } + + if (effectiveSource) { + const connectorSubgraphEdgeIds = getConnectorSubgraphEdgeIds(targetNode.id, nodes, filteredEdges); + const stagedEdges = filteredEdges.filter((edge) => !connectorSubgraphEdgeIds.has(edge.id)); + const terminalTargetEdges = getConnectorTerminalTargetEdges(targetNode.id, nodes, filteredEdges); + + for (const terminalTargetEdge of terminalTargetEdges) { + const downstreamValidation = validateConnection( + { + source: c.source, + sourceHandle: c.sourceHandle, + target: terminalTargetEdge.target, + targetHandle: terminalTargetEdge.targetHandle, + }, + nodes, + stagedEdges, + templates, + null, + true + ); + + if (downstreamValidation !== null) { + return downstreamValidation; + } + } + } + + // Unresolved connector chains are allowed to terminate on another connector. + return null; + } + + if (!isInvocationNode(targetNode)) { return 'nodes.missingInvocationTemplate'; } @@ -103,11 +291,6 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.missingInvocationTemplate'; } - const sourceFieldTemplate = sourceTemplate.outputs[c.sourceHandle]; - if (!sourceFieldTemplate) { - return 'nodes.missingFieldTemplate'; - } - const targetFieldTemplate = targetTemplate.inputs[c.targetHandle]; if (!targetFieldTemplate) { return 'nodes.missingFieldTemplate'; @@ -117,23 +300,58 @@ export const validateConnection: ValidateConnectionFunc = ( return 'nodes.cannotConnectToDirectInput'; } + if (!effectiveSource) { + if (sourceNode && isConnectorNode(sourceNode) && c.sourceHandle === CONNECTOR_OUTPUT_HANDLE) { + const existingTerminalTargetEdges = getConnectorTerminalTargetEdges(sourceNode.id, nodes, filteredEdges).filter( + (edge) => !(edge.target === c.target && edge.targetHandle === c.targetHandle) + ); + + for (const terminalTargetEdge of existingTerminalTargetEdges) { + const constrainedTargetNode = nodes.find((node) => node.id === terminalTargetEdge.target); + if (!constrainedTargetNode || !isInvocationNode(constrainedTargetNode)) { + return 'nodes.missingInvocationTemplate'; + } + + const constrainedTargetTemplate = templates[constrainedTargetNode.data.type]; + if (!constrainedTargetTemplate) { + return 'nodes.missingInvocationTemplate'; + } + + const constrainedTargetFieldTemplate = constrainedTargetTemplate.inputs[terminalTargetEdge.targetHandle]; + if (!constrainedTargetFieldTemplate) { + return 'nodes.missingFieldTemplate'; + } + + if (!areFieldTypesCompatible(constrainedTargetFieldTemplate.type, targetFieldTemplate.type)) { + return 'nodes.fieldTypesMustMatch'; + } + } + + return null; + } + + return 'nodes.fieldTypesMustMatch'; + } + + const { node: resolvedSourceNode, handle: sourceHandle, fieldTemplate: sourceFieldTemplate } = effectiveSource; + if (targetNode.data.type === 'collect' && c.targetHandle === 'item') { // Collect nodes shouldn't mix and match field types. - const collectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + const collectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id); if (collectItemType && !areTypesEqual(sourceFieldTemplate.type, collectItemType)) { return 'nodes.cannotMixAndMatchCollectionItemTypes'; } } if ( - sourceNode.data.type === 'collect' && - c.sourceHandle === 'collection' && + resolvedSourceNode.data.type === 'collect' && + sourceHandle === 'collection' && targetNode.data.type === 'collect' && c.targetHandle === 'collection' ) { // Chained collect nodes should preserve a single item type when both ends are already typed. - const sourceCollectItemType = getCollectItemType(templates, nodes, edges, sourceNode.id); - const targetCollectItemType = getCollectItemType(templates, nodes, edges, targetNode.id); + const sourceCollectItemType = getCollectItemType(templates, nodes, filteredEdges, resolvedSourceNode.id); + const targetCollectItemType = getCollectItemType(templates, nodes, filteredEdges, targetNode.id); if ( sourceCollectItemType && targetCollectItemType && @@ -148,33 +366,24 @@ export const validateConnection: ValidateConnectionFunc = ( const siblingInputEdge = filteredEdges.find((e) => e.target === c.target && e.targetHandle === siblingHandle); if (siblingInputEdge) { - if (siblingInputEdge.source === null || siblingInputEdge.source === undefined) { + const siblingEffectiveSource = getEffectiveSourceForEdge(siblingInputEdge, nodes, filteredEdges, templates); + if (siblingEffectiveSource === 'nodes.missingNode') { return 'nodes.missingNode'; } - - if (siblingInputEdge.sourceHandle === null || siblingInputEdge.sourceHandle === undefined) { - return 'nodes.missingFieldTemplate'; - } - - const siblingSourceNode = nodes.find((n) => n.id === siblingInputEdge.source); - if (!siblingSourceNode) { - return 'nodes.missingNode'; - } - - const siblingSourceTemplate = templates[siblingSourceNode.data.type]; - if (!siblingSourceTemplate) { + if (siblingEffectiveSource === 'nodes.missingInvocationTemplate') { return 'nodes.missingInvocationTemplate'; } - - const siblingSourceFieldTemplate = siblingSourceTemplate.outputs[siblingInputEdge.sourceHandle]; - if (!siblingSourceFieldTemplate) { + if (siblingEffectiveSource === 'nodes.missingFieldTemplate') { return 'nodes.missingFieldTemplate'; } + if (!siblingEffectiveSource) { + return 'nodes.fieldTypesMustMatch'; + } const areIfInputTypesCompatible = - validateConnectionTypes(sourceFieldTemplate.type, siblingSourceFieldTemplate.type) || - validateConnectionTypes(siblingSourceFieldTemplate.type, sourceFieldTemplate.type) || - isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingSourceFieldTemplate.type); + validateConnectionTypes(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type) || + validateConnectionTypes(siblingEffectiveSource.fieldTemplate.type, sourceFieldTemplate.type) || + isSingleCollectionPairOfSameBaseType(sourceFieldTemplate.type, siblingEffectiveSource.fieldTemplate.type); if (!areIfInputTypesCompatible) { return 'nodes.fieldTypesMustMatch'; @@ -189,30 +398,17 @@ export const validateConnection: ValidateConnectionFunc = ( } } - if (sourceNode.data.type === 'if' && c.sourceHandle === 'value') { + if (resolvedSourceNode.data.type === 'if' && sourceHandle === 'value') { const ifInputEdges = filteredEdges.filter( - (e) => e.target === sourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle) + (e) => + e.target === resolvedSourceNode.id && typeof e.targetHandle === 'string' && isIfInputHandle(e.targetHandle) ); const ifInputTypes = ifInputEdges.flatMap((edge) => { - if (edge.source === null || edge.source === undefined) { + const ifInputSource = getEffectiveSourceForEdge(edge, nodes, filteredEdges, templates); + if (!ifInputSource || typeof ifInputSource === 'string') { return []; } - if (edge.sourceHandle === null || edge.sourceHandle === undefined) { - return []; - } - const ifInputSourceNode = nodes.find((n) => n.id === edge.source); - if (!ifInputSourceNode) { - return []; - } - const ifInputSourceTemplate = templates[ifInputSourceNode.data.type]; - if (!ifInputSourceTemplate) { - return []; - } - const ifInputSourceFieldTemplate = ifInputSourceTemplate.outputs[edge.sourceHandle]; - if (!ifInputSourceFieldTemplate) { - return []; - } - return [ifInputSourceFieldTemplate.type]; + return [ifInputSource.fieldTemplate.type]; }); if (ifInputTypes.length > 0) { diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 8cd529deb7..9e82144e4e 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -43,6 +43,12 @@ export const zNotesNodeData = z.object({ isOpen: z.boolean(), notes: z.string(), }); +export const zConnectorNodeData = z.object({ + id: z.string().trim().min(1), + type: z.literal('connector'), + label: z.string(), + isOpen: z.boolean(), +}); const zCurrentImageNodeData = z.object({ id: z.string().trim().min(1), type: z.literal('current_image'), @@ -52,6 +58,7 @@ const zCurrentImageNodeData = z.object({ export type NotesNodeData = z.infer; export type InvocationNodeData = z.infer; +export type ConnectorNodeData = z.infer; type CurrentImageNodeData = z.infer; const zInvocationNodeValidationSchema = z.looseObject({ @@ -70,6 +77,15 @@ const zNotesNodeValidationSchema = z.looseObject({ const zNotesNode = z.custom>((val) => zNotesNodeValidationSchema.safeParse(val).success); export type NotesNode = z.infer; +const zConnectorNodeValidationSchema = z.looseObject({ + type: z.literal('connector'), + data: zConnectorNodeData, +}); +const zConnectorNode = z.custom>( + (val) => zConnectorNodeValidationSchema.safeParse(val).success +); +export type ConnectorNode = z.infer; + const zCurrentImageNodeValidationSchema = z.looseObject({ type: z.literal('current_image'), data: zCurrentImageNodeData, @@ -79,12 +95,14 @@ const zCurrentImageNode = z.custom>( ); export type CurrentImageNode = z.infer; -export const zAnyNode = z.union([zInvocationNode, zNotesNode, zCurrentImageNode]); +export const zAnyNode = z.union([zInvocationNode, zNotesNode, zConnectorNode, zCurrentImageNode]); export type AnyNode = z.infer; export const isInvocationNode = (node?: AnyNode | null): node is InvocationNode => Boolean(node && node.type === 'invocation'); export const isNotesNode = (node?: AnyNode | null): node is NotesNode => Boolean(node && node.type === 'notes'); +export const isConnectorNode = (node?: AnyNode | null): node is ConnectorNode => + Boolean(node && node.type === 'connector'); // #endregion // #region NodeExecutionState diff --git a/invokeai/frontend/web/src/features/nodes/types/workflow.ts b/invokeai/frontend/web/src/features/nodes/types/workflow.ts index 66e69ec585..34f98eb289 100644 --- a/invokeai/frontend/web/src/features/nodes/types/workflow.ts +++ b/invokeai/frontend/web/src/features/nodes/types/workflow.ts @@ -3,7 +3,7 @@ import { z } from 'zod'; import type { FieldType } from './field'; import { zFieldIdentifier } from './field'; -import { zInvocationNodeData, zNotesNodeData } from './invocation'; +import { zConnectorNodeData, zInvocationNodeData, zNotesNodeData } from './invocation'; // #region Workflow misc const zXYPosition = z @@ -31,7 +31,13 @@ const zWorkflowNotesNode = z.object({ data: zNotesNodeData, position: zXYPosition, }); -const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode]); +const zWorkflowConnectorNode = z.object({ + id: z.string().trim().min(1), + type: z.literal('connector'), + data: zConnectorNodeData, + position: zXYPosition, +}); +const zWorkflowNode = z.union([zWorkflowInvocationNode, zWorkflowNotesNode, zWorkflowConnectorNode]); type WorkflowInvocationNode = z.infer; @@ -377,7 +383,7 @@ export const zWorkflowV3 = z.object({ exposedFields: z.array(zFieldIdentifier), meta: z.object({ category: zWorkflowCategory.default('user'), - version: z.literal('3.0.0'), + version: z.literal('4.0.0'), }), // Use the validated form schema! form: zValidatedBuilderForm, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts new file mode 100644 index 0000000000..b44c6b38cd --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.test.ts @@ -0,0 +1,194 @@ +import { deepClone } from 'common/util/deepClone'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; +import { add, buildEdge, buildNode, img_resize, sub, templates } from 'features/nodes/store/util/testUtils'; +import { describe, expect, it } from 'vitest'; + +import { buildNodesGraph } from './buildNodesGraph'; + +const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, +}); + +const buildState = (nodes: unknown[], edges: unknown[]) => + ({ + nodes: { + present: { + _version: 1, + nodes, + edges, + formFieldInitialValues: {}, + id: undefined, + name: '', + author: '', + description: '', + version: '', + contact: '', + tags: '', + notes: '', + exposedFields: [], + meta: { version: '4.0.0', category: 'user' }, + form: { + rootElementId: 'root', + elements: { + root: { + id: 'root', + type: 'container', + data: { layout: 'column', children: [] }, + }, + }, + }, + }, + }, + gallery: { + autoAddBoardId: 'none', + selection: [], + }, + }) as unknown as Parameters[0]; + +describe('buildNodesGraph', () => { + it('flattens a single connector to one direct execution edge', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.nodes).not.toHaveProperty(connector.id); + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('flattens chained connectors transitively', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connectorA = buildConnectorNode('connector-a'); + const connectorB = buildConnectorNode('connector-b'); + const state = buildState( + [source, target, connectorA, connectorB], + [ + buildEdge(source.id, 'value', connectorA.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorA.id, CONNECTOR_OUTPUT_HANDLE, connectorB.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connectorB.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('fans out through a connector into multiple execution edges', () => { + const source = buildNode(add); + const targetA = buildNode(sub); + const targetB = buildNode(img_resize); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, targetA, targetB, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetA.id, 'a'), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, targetB.id, 'width'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: targetA.id, field: 'a' }, + }, + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: targetB.id, field: 'width' }, + }, + ]); + }); + + it('drops unresolved connector paths from the execution graph', () => { + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState([target, connector], [buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a')]); + + const graph = buildNodesGraph(state, templates); + + expect(graph.nodes).not.toHaveProperty(connector.id); + expect(graph.edges).toEqual([]); + }); + + it('deduplicates effective execution edges created by flattening', () => { + const source = buildNode(add); + const target = buildNode(sub); + const connector = buildConnectorNode('connector-1'); + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + buildEdge(source.id, 'value', target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + }); + + it('still omits explicit destination input values when the flattened edge exists', () => { + const source = buildNode(add); + const target = deepClone(buildNode(sub)); + const connector = buildConnectorNode('connector-1'); + const inputA = target.data.inputs.a; + expect(inputA).toBeDefined(); + if (!inputA) { + throw new Error('Missing input a'); + } + inputA.value = 'not-an-integer' as never; + const state = buildState( + [source, target, connector], + [ + buildEdge(source.id, 'value', connector.id, CONNECTOR_INPUT_HANDLE), + buildEdge(connector.id, CONNECTOR_OUTPUT_HANDLE, target.id, 'a'), + ] + ); + + const graph = buildNodesGraph(state, templates); + + expect(graph.edges).toEqual([ + { + source: { node_id: source.id, field: 'value' }, + destination: { node_id: target.id, field: 'a' }, + }, + ]); + expect(graph.nodes[target.id]).not.toHaveProperty('a'); + }); +}); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts index d83555e558..50052c806c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/buildNodesGraph.ts @@ -4,10 +4,11 @@ import { omit, reduce } from 'es-toolkit/compat'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; import { selectNodesSlice } from 'features/nodes/store/selectors'; import type { Templates } from 'features/nodes/store/types'; +import { resolveConnectorSource } from 'features/nodes/store/util/connectorTopology'; import type { BoardField } from 'features/nodes/types/common'; import type { BoardFieldInputInstance } from 'features/nodes/types/field'; import { isBoardFieldInputInstance, isBoardFieldInputTemplate } from 'features/nodes/types/field'; -import { isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isExecutableNode, isInvocationNode } from 'features/nodes/types/invocation'; import type { AnyInvocation, Graph } from 'services/api/types'; import { v4 as uuidv4 } from 'uuid'; @@ -96,12 +97,57 @@ export const buildNodesGraph = (state: RootState, templates: Templates): Require const filteredNodeIds = filteredNodes.map(({ id }) => id); // skip out the "dummy" edges between collapsed nodes - const filteredEdges = edges - .filter((edge) => edge.type !== 'collapsed') - .filter((edge) => filteredNodeIds.includes(edge.source) && filteredNodeIds.includes(edge.target)); + const flattenedEdges = edges + .filter((edge) => edge.type === 'default') + .flatMap((edge) => { + const targetNode = nodes.find((node) => node.id === edge.target); + if (!targetNode || !isInvocationNode(targetNode) || !isExecutableNode(targetNode)) { + return []; + } + + const sourceNode = nodes.find((node) => node.id === edge.source); + if (!sourceNode) { + return []; + } + + if (isInvocationNode(sourceNode)) { + if (!isExecutableNode(sourceNode) || !filteredNodeIds.includes(sourceNode.id)) { + return []; + } + return [edge]; + } + + if (isConnectorNode(sourceNode)) { + const resolvedSource = resolveConnectorSource(sourceNode.id, nodes, edges); + if (!resolvedSource || !filteredNodeIds.includes(resolvedSource.nodeId)) { + return []; + } + return [ + { + ...edge, + id: `flattened-${resolvedSource.nodeId}-${resolvedSource.fieldName}-${edge.target}-${edge.targetHandle}`, + source: resolvedSource.nodeId, + sourceHandle: resolvedSource.fieldName, + }, + ]; + } + + return []; + }) + .filter((edge, index, allEdges) => { + return ( + allEdges.findIndex( + (candidate) => + candidate.source === edge.source && + candidate.sourceHandle === edge.sourceHandle && + candidate.target === edge.target && + candidate.targetHandle === edge.targetHandle + ) === index + ); + }); // Reduce the node editor edges into invocation graph edges - const parsedEdges = filteredEdges.reduce>((edgesAccumulator, edge) => { + const parsedEdges = flattenedEdges.reduce>((edgesAccumulator, edge) => { const { source, target, sourceHandle, targetHandle } = edge; if (!sourceHandle || !targetHandle) { diff --git a/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts b/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts new file mode 100644 index 0000000000..18cb45fb1c --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/util/node/buildConnectorNode.ts @@ -0,0 +1,21 @@ +import type { XYPosition } from '@xyflow/react'; +import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; +import type { ConnectorNode } from 'features/nodes/types/invocation'; +import { v4 as uuidv4 } from 'uuid'; + +export const buildConnectorNode = (position: XYPosition): ConnectorNode => { + const nodeId = uuidv4(); + const node: ConnectorNode = { + ...SHARED_NODE_PROPERTIES, + id: nodeId, + type: 'connector', + position, + data: { + id: nodeId, + type: 'connector', + isOpen: true, + label: 'Connector', + }, + }; + return node; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts index 17cd0a33a7..21f0c38009 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/buildWorkflow.ts @@ -5,7 +5,7 @@ import { parseify } from 'common/util/serialize'; import { pick } from 'es-toolkit/compat'; import { selectNodesSlice } from 'features/nodes/store/selectors'; import type { NodesState } from 'features/nodes/store/types'; -import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; +import { isConnectorNode, isInvocationNode, isNotesNode } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { zWorkflowV3 } from 'features/nodes/types/workflow'; import i18n from 'i18n'; @@ -42,6 +42,9 @@ export const buildWorkflowFast = (nodesState: NodesState): WorkflowV3 => { if (isInvocationNode(node) && node.type) { const { id, type, data, position } = node; newWorkflow.nodes.push({ id, type, data, position }); + } else if (isConnectorNode(node) && node.type) { + const { id, type, data, position } = node; + newWorkflow.nodes.push({ id, type, data, position }); } else if (isNotesNode(node) && node.type) { const { id, type, data, position } = node; newWorkflow.nodes.push({ id, type, data, position }); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts index 05efc26fea..c09f4e1729 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts @@ -30,7 +30,7 @@ export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): Wor description: '', meta: { category: 'user', - version: '3.0.0', + version: '4.0.0', }, notes: '', tags: '', diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index 32971a02d0..638e66806d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -70,9 +70,11 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { }; const migrateV2toV3 = (workflowToMigrate: WorkflowV2): WorkflowV3 => { - // Bump version - (workflowToMigrate as unknown as WorkflowV3).meta.version = '3.0.0'; - // Parsing strips out any extra properties not in the latest version + return migrateV3toV4(workflowToMigrate as unknown as WorkflowV3); +}; + +const migrateV3toV4 = (workflowToMigrate: WorkflowV3): WorkflowV3 => { + workflowToMigrate.meta.version = '4.0.0'; return zWorkflowV3.parse(workflowToMigrate); }; @@ -100,6 +102,10 @@ export const parseAndMigrateWorkflow = (data: unknown): WorkflowV3 => { workflow = migrateV2toV3(v2); } + if (get(workflow, 'meta.version') === '3.0.0') { + workflow = migrateV3toV4(workflow as WorkflowV3); + } + // We should now have a V3 workflow const migratedWorkflow = zWorkflowV3.parse(workflow); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts index 1442a3475e..c1d0858831 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.test.ts @@ -1,4 +1,5 @@ import { get } from 'es-toolkit/compat'; +import { CONNECTOR_INPUT_HANDLE, CONNECTOR_OUTPUT_HANDLE } from 'features/nodes/store/util/connectorTopology'; import { img_resize, main_model_loader } from 'features/nodes/store/util/testUtils'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { getDefaultForm } from 'features/nodes/types/workflow'; @@ -7,6 +8,17 @@ import { describe, expect, it } from 'vitest'; //TODO(psyche): Test workflow validation for form builder fields describe('validateWorkflow', () => { + const buildConnectorNode = (id: string) => ({ + id, + type: 'connector' as const, + position: { x: 0, y: 0 }, + data: { + id, + type: 'connector' as const, + label: 'Connector', + isOpen: true, + }, + }); const getWorkflow = (): WorkflowV3 => ({ name: '', author: '', @@ -17,7 +29,7 @@ describe('validateWorkflow', () => { notes: '', exposedFields: [], form: getDefaultForm(), - meta: { version: '3.0.0', category: 'user' }, + meta: { version: '4.0.0', category: 'user' }, nodes: [ { id: '94b1d596-f2f2-4c1c-bd5b-a79c62d947ad', @@ -104,6 +116,7 @@ describe('validateWorkflow', () => { }); expect(validationResult.warnings.length).toBe(1); expect(get(validationResult, 'workflow.nodes[1].data.inputs.image.value')).toBeUndefined(); + expect(validationResult.workflow.meta.version).toBe('4.0.0'); }); it('should reset boards that are inaccessible', async () => { const validationResult = await validateWorkflow({ @@ -127,4 +140,138 @@ describe('validateWorkflow', () => { expect(validationResult.warnings.length).toBe(1); expect(get(validationResult, 'workflow.nodes[0].data.inputs.model.value')).toBeUndefined(); }); + + it('should delete malformed connector edges with invalid handles', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: 'connector-1', + targetHandle: 'wrong', + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should delete connector edges with missing endpoints', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: 'missing-node', + sourceHandle: 'value', + target: 'connector-1', + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should repair invalid multi-input connector state predictably by keeping the first valid input edge', async () => { + const workflow = getWorkflow(); + const loader2 = structuredClone(workflow.nodes[0]!); + loader2.id = 'second-loader'; + loader2.data.id = 'second-loader'; + const connector = buildConnectorNode('connector-1'); + workflow.nodes.push(loader2, connector); + workflow.edges.push({ + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + workflow.edges.push({ + id: 'e2', + type: 'default', + source: loader2.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([ + { + id: 'e1', + type: 'default', + source: workflow.nodes[0]!.id, + sourceHandle: 'vae', + target: connector.id, + targetHandle: CONNECTOR_INPUT_HANDLE, + }, + ]); + expect(validationResult.warnings.length).toBe(1); + }); + + it('should retain isolated connectors during workflow validation', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.nodes.find((node) => node.id === 'connector-1')).toBeDefined(); + expect(validationResult.warnings).toEqual([]); + }); + + it('should retain unresolved connector output edges that establish downstream constraints in the editor', async () => { + const workflow = getWorkflow(); + workflow.nodes.push(buildConnectorNode('connector-1')); + const unresolvedEdge = { + id: 'e1', + type: 'default' as const, + source: 'connector-1', + sourceHandle: CONNECTOR_OUTPUT_HANDLE, + target: workflow.nodes[1]!.id, + targetHandle: 'image', + }; + workflow.edges.push(unresolvedEdge); + + const validationResult = await validateWorkflow({ + workflow, + templates: { img_resize, main_model_loader }, + checkImageAccess: resolveTrue, + checkBoardAccess: resolveTrue, + checkModelAccess: resolveTrue, + }); + + expect(validationResult.workflow.edges).toEqual([unresolvedEdge]); + expect(validationResult.warnings).toEqual([]); + }); }); diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts index b86870d450..448214defe 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/validateWorkflow.ts @@ -1,6 +1,7 @@ import { parseify } from 'common/util/serialize'; import { addElement, getIsFormEmpty } from 'features/nodes/components/sidePanel/builder/form-manipulation'; import type { Templates } from 'features/nodes/store/types'; +import { validateConnection } from 'features/nodes/store/util/validateConnection'; import { isBoardFieldInputInstance, isImageFieldCollectionInputInstance, @@ -149,8 +150,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise(); + const validEdges = []; for (const edge of edges) { // Validate each edge. If the edge is invalid, we must remove it to prevent runtime errors with reactflow. @@ -169,7 +169,7 @@ export const validateWorkflow = async (args: ValidateWorkflowArgs): Promise !edgesToDelete.has(id)); + _workflow.edges = validEdges; // Migrated exposed fields to form elements if they exist and the form does not // Note: If the form is invalid per its zod schema, it will be reset to a default, empty form! From 9d62bfdf8e45f8b7124b3b9d344d45961fe7f41b Mon Sep 17 00:00:00 2001 From: Valeri Che <38873282+DustyShoe@users.noreply.github.com> Date: Tue, 14 Apr 2026 03:15:29 +0300 Subject: [PATCH 2/6] Feature: Add optional setting to prune queue on startup (#8861) * Add more settings to invokeai.yaml for improved queue management. * Adjusted description * More logic tweaking * chore(api): update generated schema types * chore(api): update generated schema types * Add: UI element for max_queue_history to 'Settings' modal. Now it is possible to set Max queue history in both places: .yaml and UI. * chore(api): regenerate schema types * chore(api): normalize generated schema path defaults --------- Co-authored-by: dunkeroni --- invokeai/app/api/routers/app_info.py | 42 ++++++- .../app/services/config/config_default.py | 6 +- .../session_queue/session_queue_sqlite.py | 56 ++++++++- invokeai/frontend/web/public/locales/en.json | 2 + .../SettingsModal/SettingsModal.tsx | 110 +++++++++++++++++- .../web/src/services/api/endpoints/appInfo.ts | 21 ++++ .../frontend/web/src/services/api/schema.ts | 57 ++++++++- 7 files changed, 283 insertions(+), 11 deletions(-) diff --git a/invokeai/app/api/routers/app_info.py b/invokeai/app/api/routers/app_info.py index d8f3bb2f80..da777ebc73 100644 --- a/invokeai/app/api/routers/app_info.py +++ b/invokeai/app/api/routers/app_info.py @@ -6,8 +6,14 @@ from fastapi import Body from fastapi.routing import APIRouter from pydantic import BaseModel, Field +from invokeai.app.api.auth_dependencies import AdminUserOrDefault from invokeai.app.api.dependencies import ApiDependencies -from invokeai.app.services.config.config_default import InvokeAIAppConfig, get_config +from invokeai.app.services.config.config_default import ( + DefaultInvokeAIAppConfig, + InvokeAIAppConfig, + get_config, + load_and_migrate_config, +) from invokeai.app.services.invocation_cache.invocation_cache_common import InvocationCacheStatus from invokeai.backend.image_util.infill_methods.patchmatch import PatchMatch from invokeai.backend.util.logging import logging @@ -64,6 +70,16 @@ class InvokeAIAppConfigWithSetFields(BaseModel): config: InvokeAIAppConfig = Field(description="The InvokeAI App Config") +class UpdateAppGenerationSettingsRequest(BaseModel): + """Writable generation-related app settings.""" + + max_queue_history: int | None = Field( + default=None, + ge=0, + description="Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items.", + ) + + @app_router.get( "/runtime_config", operation_id="get_runtime_config", status_code=200, response_model=InvokeAIAppConfigWithSetFields ) @@ -72,6 +88,30 @@ async def get_runtime_config() -> InvokeAIAppConfigWithSetFields: return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config) +@app_router.patch( + "/runtime_config", + operation_id="update_runtime_config", + status_code=200, + response_model=InvokeAIAppConfigWithSetFields, +) +async def update_runtime_config( + _: AdminUserOrDefault, + changes: UpdateAppGenerationSettingsRequest = Body(description="Writable runtime configuration changes"), +) -> InvokeAIAppConfigWithSetFields: + config = get_config() + update_dict = changes.model_dump(exclude_unset=True) + config.update_config(update_dict) + + if config.config_file_path.exists(): + persisted_config = load_and_migrate_config(config.config_file_path) + else: + persisted_config = DefaultInvokeAIAppConfig() + + persisted_config.update_config(update_dict) + persisted_config.write_file(config.config_file_path) + return InvokeAIAppConfigWithSetFields(set_fields=config.model_fields_set, config=config) + + @app_router.get( "/logging", operation_id="get_log_level", diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 5d1b1d0d8d..7e56db9f61 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -101,7 +101,8 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. max_queue_size: Maximum number of items in the session queue. - clear_queue_on_startup: Empties session queue on startup. + clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. + max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. allow_nodes: List of nodes to allow. Omit to allow all. deny_nodes: List of nodes to deny. Omit to deny none. node_cache_size: How many cached nodes to keep in memory. @@ -191,7 +192,8 @@ class InvokeAIAppConfig(BaseSettings): force_tiled_decode: bool = Field(default=False, description="Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty).") pil_compress_level: int = Field(default=1, description="The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting.") max_queue_size: int = Field(default=10000, gt=0, description="Maximum number of items in the session queue.") - clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup.") + clear_queue_on_startup: bool = Field(default=False, description="Empties session queue on startup. If true, disables `max_queue_history`.") + max_queue_history: Optional[int] = Field(default=None, ge=0, description="Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true.") # NODES allow_nodes: Optional[list[str]] = Field(default=None, description="List of nodes to allow. Omit to allow all.") diff --git a/invokeai/app/services/session_queue/session_queue_sqlite.py b/invokeai/app/services/session_queue/session_queue_sqlite.py index 070a7cef29..172dc08d55 100644 --- a/invokeai/app/services/session_queue/session_queue_sqlite.py +++ b/invokeai/app/services/session_queue/session_queue_sqlite.py @@ -45,10 +45,19 @@ class SqliteSessionQueue(SessionQueueBase): def start(self, invoker: Invoker) -> None: self.__invoker = invoker self._set_in_progress_to_canceled() - if self.__invoker.services.configuration.clear_queue_on_startup: + config = self.__invoker.services.configuration + if config.clear_queue_on_startup: clear_result = self.clear(DEFAULT_QUEUE_ID) if clear_result.deleted > 0: self.__invoker.services.logger.info(f"Cleared all {clear_result.deleted} queue items") + return + + if config.max_queue_history is not None: + deleted = self._prune_terminal_to_limit(DEFAULT_QUEUE_ID, config.max_queue_history) + if deleted > 0: + self.__invoker.services.logger.info( + f"Pruned {deleted} completed/failed/canceled queue items (kept up to {config.max_queue_history})" + ) def __init__(self, db: SqliteDatabase) -> None: super().__init__() @@ -68,6 +77,51 @@ class SqliteSessionQueue(SessionQueueBase): """ ) + def _prune_terminal_to_limit(self, queue_id: str, keep: int) -> int: + """Prune terminal items (completed/failed/canceled) to keep at most N most-recent items.""" + with self._db.transaction() as cursor: + where = """--sql + WHERE + queue_id = ? + AND ( + status = 'completed' + OR status = 'failed' + OR status = 'canceled' + ) + """ + cursor.execute( + f"""--sql + SELECT COUNT(*) + FROM session_queue + {where} + AND item_id NOT IN ( + SELECT item_id + FROM session_queue + {where} + ORDER BY COALESCE(completed_at, updated_at, created_at) DESC, item_id DESC + LIMIT ? + ); + """, + (queue_id, queue_id, keep), + ) + count = cursor.fetchone()[0] + cursor.execute( + f"""--sql + DELETE + FROM session_queue + {where} + AND item_id NOT IN ( + SELECT item_id + FROM session_queue + {where} + ORDER BY COALESCE(completed_at, updated_at, created_at) DESC, item_id DESC + LIMIT ? + ); + """, + (queue_id, queue_id, keep), + ) + return count + def _get_current_queue_size(self, queue_id: str) -> int: """Gets the current number of pending queue items""" with self._db.transaction() as cursor: diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 7fad98531e..cd7287d46e 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1707,6 +1707,8 @@ "enableNSFWChecker": "Enable NSFW Checker", "general": "General", "generation": "Generation", + "maxQueueHistory": "Max Queue History", + "maxQueueHistorySaveFailed": "Failed to save Max Queue History", "models": "Models", "preferAttentionStyleNumeric": "Prefer Numeric Attention Style", "prompt": "Prompt", diff --git a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx index b95d2adb47..7bcc103402 100644 --- a/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx +++ b/invokeai/frontend/web/src/features/system/components/SettingsModal/SettingsModal.tsx @@ -11,6 +11,8 @@ import { ModalFooter, ModalHeader, ModalOverlay, + NumberInput, + NumberInputField, Switch, Text, } from '@invoke-ai/ui-library'; @@ -19,6 +21,7 @@ import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { InformationalPopover } from 'common/components/InformationalPopover/InformationalPopover'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import { buildUseBoolean } from 'common/hooks/useBoolean'; +import { selectCurrentUser } from 'features/auth/store/authSlice'; import { selectShouldUseCPUNoise, shouldUseCpuNoiseChanged } from 'features/controlLayers/store/paramsSlice'; import { useRefreshAfterResetModal } from 'features/system/components/SettingsModal/RefreshAfterResetModal'; import { SettingsDeveloperLogIsEnabled } from 'features/system/components/SettingsModal/SettingsDeveloperLogIsEnabled'; @@ -48,16 +51,25 @@ import { shouldUseNSFWCheckerChanged, shouldUseWatermarkerChanged, } from 'features/system/store/systemSlice'; +import { toast } from 'features/toast/toast'; import { selectShouldShowProgressInViewer } from 'features/ui/store/uiSelectors'; import { setShouldShowProgressInViewer } from 'features/ui/store/uiSlice'; -import type { ChangeEvent, ReactElement } from 'react'; -import { cloneElement, memo, useCallback, useEffect } from 'react'; +import type { ChangeEvent, KeyboardEvent, ReactElement } from 'react'; +import { cloneElement, memo, useCallback, useEffect, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; +import { useGetRuntimeConfigQuery, useUpdateRuntimeConfigMutation } from 'services/api/endpoints/appInfo'; import { SettingsLanguageSelect } from './SettingsLanguageSelect'; const [useSettingsModal] = buildUseBoolean(false); +const formatOptionalInteger = (value: number | null | undefined) => { + if (value === null || value === undefined) { + return ''; + } + return String(value); +}; + const SettingsModal = (props: { children: ReactElement }) => { const dispatch = useAppDispatch(); const { t } = useTranslation(); @@ -72,6 +84,10 @@ const SettingsModal = (props: { children: ReactElement }) => { const settingsModal = useSettingsModal(); const refreshModal = useRefreshAfterResetModal(); + const currentUser = useAppSelector(selectCurrentUser); + const { data: runtimeConfig } = useGetRuntimeConfigQuery(); + const [updateRuntimeConfig, { isLoading: isUpdatingRuntimeConfig }] = useUpdateRuntimeConfigMutation(); + const pendingMaxQueueHistoryRef = useRef(undefined); const prefersNumericAttentionWeights = useAppSelector(selectSystemPrefersNumericAttentionWeights); const shouldUseCpuNoise = useAppSelector(selectShouldUseCPUNoise); @@ -85,6 +101,10 @@ const SettingsModal = (props: { children: ReactElement }) => { const shouldHighlightFocusedRegions = useAppSelector(selectSystemShouldEnableHighlightFocusedRegions); const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession); const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail); + const maxQueueHistory = runtimeConfig?.config.max_queue_history ?? null; + const canEditRuntimeConfig = runtimeConfig ? !runtimeConfig.config.multiuser || currentUser?.is_admin : false; + const [maxQueueHistoryInput, setMaxQueueHistoryInput] = useState(formatOptionalInteger(maxQueueHistory)); + const onToggleConfirmOnNewSession = useCallback(() => { dispatch(shouldConfirmOnNewSessionToggled()); }, [dispatch]); @@ -96,11 +116,60 @@ const SettingsModal = (props: { children: ReactElement }) => { } }, [refetchIntermediatesCount, settingsModal.isTrue]); + useEffect(() => { + setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory)); + }, [maxQueueHistory]); + + const commitMaxQueueHistory = useCallback(async () => { + if (!runtimeConfig || !canEditRuntimeConfig) { + return; + } + + const trimmedValue = maxQueueHistoryInput.trim(); + const parsedValue = trimmedValue === '' ? null : Number.parseInt(trimmedValue, 10); + + if (parsedValue !== null && Number.isNaN(parsedValue)) { + setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory)); + return; + } + + const normalizedValue = parsedValue === null ? null : Math.max(0, parsedValue); + const currentValue = + pendingMaxQueueHistoryRef.current === undefined ? maxQueueHistory : pendingMaxQueueHistoryRef.current; + + if (normalizedValue === currentValue) { + setMaxQueueHistoryInput(formatOptionalInteger(currentValue)); + return; + } + + pendingMaxQueueHistoryRef.current = normalizedValue; + setMaxQueueHistoryInput(formatOptionalInteger(normalizedValue)); + + try { + await updateRuntimeConfig({ max_queue_history: normalizedValue }).unwrap(); + } catch { + setMaxQueueHistoryInput(formatOptionalInteger(maxQueueHistory)); + toast({ + id: 'SETTINGS_MAX_QUEUE_HISTORY_SAVE_FAILED', + title: t('settings.maxQueueHistorySaveFailed'), + status: 'error', + }); + } finally { + pendingMaxQueueHistoryRef.current = undefined; + } + }, [canEditRuntimeConfig, maxQueueHistory, maxQueueHistoryInput, runtimeConfig, t, updateRuntimeConfig]); + + const handleCloseSettingsModal = useCallback(() => { + void commitMaxQueueHistory(); + settingsModal.setFalse(); + }, [commitMaxQueueHistory, settingsModal]); + const handleClickResetWebUI = useCallback(() => { + void commitMaxQueueHistory(); clearStorage(); settingsModal.setFalse(); refreshModal.setTrue(); - }, [settingsModal, refreshModal]); + }, [commitMaxQueueHistory, refreshModal, settingsModal]); const handleChangeShouldConfirmOnDelete = useCallback( (e: ChangeEvent) => { @@ -172,12 +241,30 @@ const SettingsModal = (props: { children: ReactElement }) => { [dispatch] ); + const handleChangeMaxQueueHistory = useCallback((valueAsString: string) => { + setMaxQueueHistoryInput(valueAsString); + }, []); + + const handleBlurMaxQueueHistory = useCallback(() => { + void commitMaxQueueHistory(); + }, [commitMaxQueueHistory]); + + const handleKeyDownMaxQueueHistory = useCallback( + (e: KeyboardEvent) => { + if (e.key === 'Enter') { + void commitMaxQueueHistory(); + e.currentTarget.blur(); + } + }, + [commitMaxQueueHistory] + ); + return ( <> {cloneElement(props.children, { onClick: settingsModal.setTrue, })} - + {t('common.settingsLabel')} @@ -206,6 +293,21 @@ const SettingsModal = (props: { children: ReactElement }) => { {t('settings.enableInvisibleWatermark')} + + {t('settings.maxQueueHistory')} + + + + diff --git a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts index 8fe85125e6..1c656d289b 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/appInfo.ts @@ -51,6 +51,26 @@ export const appInfoApi = api.injectEndpoints({ url: buildAppInfoUrl('runtime_config'), method: 'GET', }), + providesTags: ['AppConfig'], + }), + updateRuntimeConfig: build.mutation< + paths['/api/v1/app/runtime_config']['patch']['responses']['200']['content']['application/json'], + paths['/api/v1/app/runtime_config']['patch']['requestBody']['content']['application/json'] + >({ + query: (body) => ({ + url: buildAppInfoUrl('runtime_config'), + method: 'PATCH', + body, + }), + async onQueryStarted(_, { dispatch, queryFulfilled }) { + try { + const { data } = await queryFulfilled; + dispatch(appInfoApi.util.upsertQueryData('getRuntimeConfig', undefined, data)); + } catch { + // no-op + } + }, + invalidatesTags: ['AppConfig'], }), getInvocationCacheStatus: build.query< paths['/api/v1/app/invocation_cache/status']['get']['responses']['200']['content']['application/json'], @@ -95,6 +115,7 @@ export const { useGetAppDepsQuery, useGetPatchmatchStatusQuery, useGetRuntimeConfigQuery, + useUpdateRuntimeConfigMutation, useClearInvocationCacheMutation, useDisableInvocationCacheMutation, useEnableInvocationCacheMutation, diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 4f10cf6c48..2e93e98ad5 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -1586,7 +1586,8 @@ export type paths = { delete?: never; options?: never; head?: never; - patch?: never; + /** Update Runtime Config */ + patch: operations["update_runtime_config"]; trace?: never; }; "/api/v1/app/logging": { @@ -15199,7 +15200,8 @@ export type components = { * force_tiled_decode: Whether to enable tiled VAE decode (reduces memory consumption with some performance penalty). * pil_compress_level: The compress_level setting of PIL.Image.save(), used for PNG encoding. All settings are lossless. 0 = no compression, 1 = fastest with slightly larger filesize, 9 = slowest with smallest filesize. 1 is typically the best setting. * max_queue_size: Maximum number of items in the session queue. - * clear_queue_on_startup: Empties session queue on startup. + * clear_queue_on_startup: Empties session queue on startup. If true, disables `max_queue_history`. + * max_queue_history: Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. * allow_nodes: List of nodes to allow. Omit to allow all. * deny_nodes: List of nodes to deny. Omit to deny none. * node_cache_size: How many cached nodes to keep in memory. @@ -15527,10 +15529,15 @@ export type components = { max_queue_size?: number; /** * Clear Queue On Startup - * @description Empties session queue on startup. + * @description Empties session queue on startup. If true, disables `max_queue_history`. * @default false */ clear_queue_on_startup?: boolean; + /** + * Max Queue History + * @description Keep the last N completed, failed, and canceled queue items. Older items are deleted on startup. Set to 0 to prune all terminal items. Ignored if `clear_queue_on_startup` is true. + */ + max_queue_history?: number | null; /** * Allow Nodes * @description List of nodes to allow. Omit to allow all. @@ -28585,6 +28592,17 @@ export type components = { */ unstarred_images: string[]; }; + /** + * UpdateAppGenerationSettingsRequest + * @description Writable generation-related app settings. + */ + UpdateAppGenerationSettingsRequest: { + /** + * Max Queue History + * @description Keep the last N completed, failed, and canceled queue items on startup. Set to 0 to prune all terminal items. + */ + max_queue_history?: number | null; + }; /** * UserDTO * @description User data transfer object. @@ -33751,6 +33769,39 @@ export interface operations { }; }; }; + update_runtime_config: { + parameters: { + query?: never; + header?: never; + path?: never; + cookie?: never; + }; + requestBody: { + content: { + "application/json": components["schemas"]["UpdateAppGenerationSettingsRequest"]; + }; + }; + responses: { + /** @description Successful Response */ + 200: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["InvokeAIAppConfigWithSetFields"]; + }; + }; + /** @description Validation Error */ + 422: { + headers: { + [name: string]: unknown; + }; + content: { + "application/json": components["schemas"]["HTTPValidationError"]; + }; + }; + }; + }; get_log_level: { parameters: { query?: never; From 441821ca0308546736c25806ba8e510b12d7411f Mon Sep 17 00:00:00 2001 From: Valeri Che <38873282+DustyShoe@users.noreply.github.com> Date: Tue, 14 Apr 2026 03:22:34 +0300 Subject: [PATCH 3/6] Feat(canvas): Add Lasso Tool with Freehand and Polygon modes (#8908) * Feat(Canvas): Add Lasso tool with Freehand and Polygon modes * Refine Lasso modes behavior and optimisation. * Fix: Pettier * added docs/features/Lasso_tool.md * Fix: Removed restrictions mentioned in PR's conversation: 1. Disabled when there is no visible raster content 2. Lasso is blocked when all inpaint masks are globally hidden. --------- Co-authored-by: dunkeroni --- docs/features/Lasso_tool.md | 32 + invokeai/frontend/web/public/locales/en.json | 10 + .../components/Tool/ToolChooser.tsx | 2 + .../components/Tool/ToolLassoButton.tsx | 34 ++ .../components/Tool/ToolLassoModeToggle.tsx | 47 ++ .../components/Toolbar/CanvasToolbar.tsx | 7 + .../CanvasEntityBufferObjectRenderer.ts | 13 + .../CanvasEntityObjectRenderer.ts | 19 +- .../konva/CanvasObject/CanvasObjectLasso.ts | 85 +++ .../controlLayers/konva/CanvasObject/types.ts | 4 + .../konva/CanvasStateApiModule.ts | 9 + .../konva/CanvasTool/CanvasLassoToolModule.ts | 566 ++++++++++++++++++ .../konva/CanvasTool/CanvasToolModule.ts | 118 +++- .../store/canvasSettingsSlice.ts | 11 + .../controlLayers/store/canvasSlice.ts | 37 +- .../features/controlLayers/store/selectors.ts | 6 +- .../src/features/controlLayers/store/types.ts | 18 +- .../components/HotkeysModal/useHotkeyData.ts | 1 + 18 files changed, 1005 insertions(+), 14 deletions(-) create mode 100644 docs/features/Lasso_tool.md create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoButton.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoModeToggle.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectLasso.ts create mode 100644 invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasLassoToolModule.ts diff --git a/docs/features/Lasso_tool.md b/docs/features/Lasso_tool.md new file mode 100644 index 0000000000..8f7fc6d4ec --- /dev/null +++ b/docs/features/Lasso_tool.md @@ -0,0 +1,32 @@ +Lasso Tool +=========== + +- The Lasso tool creates selections and inpaint masks by drawing freehand or polygonal regions on the canvas. + +How to open the Lasso tool +-------------------------- +- Click the Lasso icon in the toolbar. +- Hotkey: press `L` (default). The hotkey is shown in the tool's tooltip and can be customized in Hotkeys settings. + +Modes +----- +- Freehand (default) + - Hold the pointer and drag to draw a continuous contour. + - Long segments are broken into intermediate points to keep the line continuous. + - Very long strokes may be simplified after drawing to reduce point count for performance. + +- Polygon + - Click to place points; click the first point (or a point near it) to close the polygon. + - The tool snaps the closing point to the start for precise closures. + +Basic interactions +------------------ +- Switch modes with the mode toggle in the toolbar. +- To close a polygon: click the starting point again or click near it — the tool aligns the final point to the start to complete the shape. +- The selection will be added to the current Inpaint Mask layer. If no Inpaint Mask layer exists, a new one will be created automatically. + +Tips & behavior +--------------- +- Hold `Space` to temporarily switch to the View tool for panning and zooming; release `Space` to return to the Lasso tool and continue drawing. +- When using the Polygon mode, you can hold `Shift` to snap points to horizontal, vertical, or 45-degree angles for more precise shapes. +- Hold `Ctrl` (Windows/Linux) or `Command` (macOS) while drawing to subtract from the current selection instead of adding to it. diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index cd7287d46e..0e2f982ea6 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -715,6 +715,10 @@ "title": "Rect Tool", "desc": "Select the rect tool." }, + "selectLassoTool": { + "title": "Lasso Tool", + "desc": "Select the lasso tool." + }, "selectViewTool": { "title": "View Tool", "desc": "Select the view tool." @@ -2762,10 +2766,16 @@ "radial": "Radial", "clip": "Clip Gradient" }, + "lasso": { + "freehand": "Freehand", + "polygon": "Polygon", + "polygonHint": "Click to add points, click the first point to close." + }, "tool": { "brush": "Brush", "eraser": "Eraser", "rectangle": "Rectangle", + "lasso": "Lasso", "gradient": "Gradient", "bbox": "Bbox", "move": "Move", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx index 44efe12eb9..30d8272207 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolChooser.tsx @@ -3,6 +3,7 @@ import { ToolBboxButton } from 'features/controlLayers/components/Tool/ToolBboxB import { ToolBrushButton } from 'features/controlLayers/components/Tool/ToolBrushButton'; import { ToolColorPickerButton } from 'features/controlLayers/components/Tool/ToolColorPickerButton'; import { ToolGradientButton } from 'features/controlLayers/components/Tool/ToolGradientButton'; +import { ToolLassoButton } from 'features/controlLayers/components/Tool/ToolLassoButton'; import { ToolMoveButton } from 'features/controlLayers/components/Tool/ToolMoveButton'; import { ToolRectButton } from 'features/controlLayers/components/Tool/ToolRectButton'; import { ToolTextButton } from 'features/controlLayers/components/Tool/ToolTextButton'; @@ -20,6 +21,7 @@ export const ToolChooser: React.FC = () => { + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoButton.tsx new file mode 100644 index 0000000000..587e8a7223 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoButton.tsx @@ -0,0 +1,34 @@ +import { IconButton, Tooltip } from '@invoke-ai/ui-library'; +import { useSelectTool, useToolIsSelected } from 'features/controlLayers/components/Tool/hooks'; +import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiLassoBold } from 'react-icons/pi'; + +export const ToolLassoButton = memo(() => { + const { t } = useTranslation(); + const isSelected = useToolIsSelected('lasso'); + const selectLasso = useSelectTool('lasso'); + + useRegisteredHotkeys({ + id: 'selectLassoTool', + category: 'canvas', + callback: selectLasso, + options: { enabled: !isSelected }, + dependencies: [isSelected, selectLasso], + }); + + return ( + + } + colorScheme={isSelected ? 'invokeBlue' : 'base'} + variant="solid" + onClick={selectLasso} + /> + + ); +}); + +ToolLassoButton.displayName = 'ToolLassoButton'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoModeToggle.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoModeToggle.tsx new file mode 100644 index 0000000000..63aa27c609 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Tool/ToolLassoModeToggle.tsx @@ -0,0 +1,47 @@ +import { ButtonGroup, IconButton, Tooltip } from '@invoke-ai/ui-library'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { selectLassoMode, settingsLassoModeChanged } from 'features/controlLayers/store/canvasSettingsSlice'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiPolygonBold, PiScribbleLoopBold } from 'react-icons/pi'; + +export const ToolLassoModeToggle = memo(() => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + const lassoMode = useAppSelector(selectLassoMode); + + const setFreehand = useCallback(() => { + dispatch(settingsLassoModeChanged('freehand')); + }, [dispatch]); + + const setPolygon = useCallback(() => { + dispatch(settingsLassoModeChanged('polygon')); + }, [dispatch]); + + return ( + + + } + colorScheme={lassoMode === 'freehand' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={setFreehand} + /> + + + } + colorScheme={lassoMode === 'polygon' ? 'invokeBlue' : 'base'} + variant="solid" + onClick={setPolygon} + /> + + + ); +}); + +ToolLassoModeToggle.displayName = 'ToolLassoModeToggle'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx index bf186ed630..faea5d98c3 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx @@ -5,6 +5,7 @@ import { useToolIsSelected } from 'features/controlLayers/components/Tool/hooks' import { ToolFillColorPicker } from 'features/controlLayers/components/Tool/ToolFillColorPicker'; import { ToolGradientClipToggle } from 'features/controlLayers/components/Tool/ToolGradientClipToggle'; import { ToolGradientModeToggle } from 'features/controlLayers/components/Tool/ToolGradientModeToggle'; +import { ToolLassoModeToggle } from 'features/controlLayers/components/Tool/ToolLassoModeToggle'; import { ToolOptionsRowContainer } from 'features/controlLayers/components/Tool/ToolOptionsRowContainer'; import { ToolWidthPicker } from 'features/controlLayers/components/Tool/ToolWidthPicker'; import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton'; @@ -31,6 +32,7 @@ export const CanvasToolbar = memo(() => { const isBrushSelected = useToolIsSelected('brush'); const isEraserSelected = useToolIsSelected('eraser'); const isTextSelected = useToolIsSelected('text'); + const isLassoSelected = useToolIsSelected('lasso'); const isGradientSelected = useToolIsSelected('gradient'); const showToolWithPicker = useMemo(() => { return !isTextSelected && (isBrushSelected || isEraserSelected); @@ -57,6 +59,11 @@ export const CanvasToolbar = memo(() => { )} + {isLassoSelected && ( + + + + )} {isTextSelected ? : showToolWithPicker && } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts index ef5cee8d89..9941761a2e 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer.ts @@ -8,6 +8,7 @@ import { CanvasObjectEraserLine } from 'features/controlLayers/konva/CanvasObjec import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure'; import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; +import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; import { getPrefixedId } from 'features/controlLayers/konva/util'; @@ -152,6 +153,15 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { this.konva.group.add(this.renderer.konva.group); } + didRender = this.renderer.update(this.state, true); + } else if (this.state.type === 'lasso') { + assert(this.renderer instanceof CanvasObjectLasso || !this.renderer); + + if (!this.renderer) { + this.renderer = new CanvasObjectLasso(this.state, this); + this.konva.group.add(this.renderer.konva.group); + } + didRender = this.renderer.update(this.state, true); } else if (this.state.type === 'gradient') { assert(this.renderer instanceof CanvasObjectGradient || !this.renderer); @@ -247,6 +257,9 @@ export class CanvasEntityBufferObjectRenderer extends CanvasModuleBase { case 'rect': this.manager.stateApi.addRect({ entityIdentifier, rect: this.state }); break; + case 'lasso': + this.manager.stateApi.addLasso({ entityIdentifier, lasso: this.state }); + break; case 'gradient': this.manager.stateApi.addGradient({ entityIdentifier, gradient: this.state }); break; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts index 1498a6cbb5..903ccaa772 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer.ts @@ -10,6 +10,7 @@ import { CanvasObjectEraserLine } from 'features/controlLayers/konva/CanvasObjec import { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure'; import { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; +import { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; import { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { AnyObjectRenderer, AnyObjectState } from 'features/controlLayers/konva/CanvasObject/types'; import { LightnessToAlphaFilter } from 'features/controlLayers/konva/filters'; @@ -397,6 +398,16 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { this.konva.objectGroup.add(renderer.konva.group); } + didRender = renderer.update(objectState, force || isFirstRender); + } else if (objectState.type === 'lasso') { + assert(renderer instanceof CanvasObjectLasso || !renderer); + + if (!renderer) { + renderer = new CanvasObjectLasso(objectState, this); + this.renderers.set(renderer.id, renderer); + this.konva.objectGroup.add(renderer.konva.group); + } + didRender = renderer.update(objectState, force || isFirstRender); } else if (objectState.type === 'gradient') { assert(renderer instanceof CanvasObjectGradient || !renderer); @@ -433,17 +444,21 @@ export class CanvasEntityObjectRenderer extends CanvasModuleBase { * these visually transparent shapes in its calculation: * * - Eraser lines, which are normal lines with a globalCompositeOperation of 'destination-out'. + * - Subtracting lasso shapes, which use a globalCompositeOperation of 'destination-out'. * - Clipped portions of any shape. * - Images, which may have transparent areas. */ needsPixelBbox = (): boolean => { let needsPixelBbox = false; for (const renderer of this.renderers.values()) { - const isEraserLine = renderer instanceof CanvasObjectEraserLine; + const isEraserLine = + renderer instanceof CanvasObjectEraserLine || renderer instanceof CanvasObjectEraserLineWithPressure; + const isSubtractingLasso = + renderer instanceof CanvasObjectLasso && renderer.state.compositeOperation === 'destination-out'; const isImage = renderer instanceof CanvasObjectImage; const imageIgnoresTransparency = isImage && renderer.state.usePixelBbox === false; const hasClip = renderer instanceof CanvasObjectBrushLine && renderer.state.clip; - if (isEraserLine || hasClip || (isImage && !imageIgnoresTransparency)) { + if (isEraserLine || isSubtractingLasso || hasClip || (isImage && !imageIgnoresTransparency)) { needsPixelBbox = true; break; } diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectLasso.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectLasso.ts new file mode 100644 index 0000000000..ad433a9f67 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/CanvasObjectLasso.ts @@ -0,0 +1,85 @@ +import { deepClone } from 'common/util/deepClone'; +import type { CanvasEntityBufferObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityBufferObjectRenderer'; +import type { CanvasEntityObjectRenderer } from 'features/controlLayers/konva/CanvasEntity/CanvasEntityObjectRenderer'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasLassoState } from 'features/controlLayers/store/types'; +import Konva from 'konva'; +import type { Logger } from 'roarr'; + +export class CanvasObjectLasso extends CanvasModuleBase { + readonly type = 'object_lasso'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer; + readonly manager: CanvasManager; + readonly log: Logger; + + state: CanvasLassoState; + konva: { + group: Konva.Group; + line: Konva.Line; + }; + + constructor(state: CanvasLassoState, parent: CanvasEntityObjectRenderer | CanvasEntityBufferObjectRenderer) { + super(); + this.id = state.id; + this.parent = parent; + this.manager = parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + + this.log.debug({ state }, 'Creating module'); + + this.konva = { + group: new Konva.Group({ + name: `${this.type}:group`, + listening: false, + }), + line: new Konva.Line({ + name: `${this.type}:line`, + listening: false, + closed: true, + fill: 'white', + strokeEnabled: false, + perfectDrawEnabled: false, + }), + }; + this.konva.group.add(this.konva.line); + this.state = state; + } + + update(state: CanvasLassoState, force = false): boolean { + if (force || this.state !== state) { + this.log.trace({ state }, 'Updating lasso'); + this.konva.line.setAttrs({ + points: state.points, + globalCompositeOperation: state.compositeOperation, + }); + this.state = state; + return true; + } + + return false; + } + + setVisibility(isVisible: boolean): void { + this.log.trace({ isVisible }, 'Setting lasso visibility'); + this.konva.group.visible(isVisible); + } + + destroy = () => { + this.log.debug('Destroying module'); + this.konva.group.destroy(); + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + parent: this.parent.id, + state: deepClone(this.state), + }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts index 31a2bfee07..f193c0b391 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasObject/types.ts @@ -4,6 +4,7 @@ import type { CanvasObjectEraserLine } from 'features/controlLayers/konva/Canvas import type { CanvasObjectEraserLineWithPressure } from 'features/controlLayers/konva/CanvasObject/CanvasObjectEraserLineWithPressure'; import type { CanvasObjectGradient } from 'features/controlLayers/konva/CanvasObject/CanvasObjectGradient'; import type { CanvasObjectImage } from 'features/controlLayers/konva/CanvasObject/CanvasObjectImage'; +import type { CanvasObjectLasso } from 'features/controlLayers/konva/CanvasObject/CanvasObjectLasso'; import type { CanvasObjectRect } from 'features/controlLayers/konva/CanvasObject/CanvasObjectRect'; import type { CanvasBrushLineState, @@ -12,6 +13,7 @@ import type { CanvasEraserLineWithPressureState, CanvasGradientState, CanvasImageState, + CanvasLassoState, CanvasRectState, } from 'features/controlLayers/store/types'; @@ -25,6 +27,7 @@ export type AnyObjectRenderer = | CanvasObjectEraserLine | CanvasObjectEraserLineWithPressure | CanvasObjectRect + | CanvasObjectLasso | CanvasObjectImage | CanvasObjectGradient; /** @@ -37,4 +40,5 @@ export type AnyObjectState = | CanvasEraserLineWithPressureState | CanvasImageState | CanvasRectState + | CanvasLassoState | CanvasGradientState; diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts index e00ade1f8b..7d4c76b0c0 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasStateApiModule.ts @@ -21,6 +21,7 @@ import { entityBrushLineAdded, entityEraserLineAdded, entityGradientAdded, + entityLassoAdded, entityMovedBy, entityMovedTo, entityRasterized, @@ -43,6 +44,7 @@ import type { EntityEraserLineAddedPayload, EntityGradientAddedPayload, EntityIdentifierPayload, + EntityLassoAddedPayload, EntityMovedByPayload, EntityMovedToPayload, EntityRasterizedPayload, @@ -175,6 +177,13 @@ export class CanvasStateApiModule extends CanvasModuleBase { this.store.dispatch(entityRectAdded(arg)); }; + /** + * Adds a lasso object to an entity, pushing state to redux. + */ + addLasso = (arg: EntityLassoAddedPayload) => { + this.store.dispatch(entityLassoAdded(arg)); + }; + /** * Adds a gradient to an entity, pushing state to redux. */ diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasLassoToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasLassoToolModule.ts new file mode 100644 index 0000000000..12f2638abc --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasLassoToolModule.ts @@ -0,0 +1,566 @@ +import { rgbaColorToString } from 'common/util/colorCodeTransformers'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { CanvasModuleBase } from 'features/controlLayers/konva/CanvasModuleBase'; +import type { CanvasToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasToolModule'; +import { getPrefixedId, isDistanceMoreThanMin, offsetCoord } from 'features/controlLayers/konva/util'; +import type { Coordinate } from 'features/controlLayers/store/types'; +import { simplifyFlatNumbersArray } from 'features/controlLayers/util/simplify'; +import Konva from 'konva'; +import type { KonvaEventObject } from 'konva/lib/Node'; +import type { Logger } from 'roarr'; + +type CanvasLassoToolModuleConfig = { + PREVIEW_STROKE_COLOR: string; + PREVIEW_FILL_COLOR: string; + PREVIEW_STROKE_WIDTH_PX: number; + START_POINT_RADIUS_PX: number; + START_POINT_STROKE_WIDTH_PX: number; + START_POINT_HOVER_RADIUS_DELTA_PX: number; + POLYGON_CLOSE_RADIUS_PX: number; + MIN_FREEHAND_POINT_DISTANCE_PX: number; + MAX_FREEHAND_SEGMENT_LENGTH_PX: number; + FREEHAND_SIMPLIFY_MIN_POINTS: number; + FREEHAND_SIMPLIFY_TOLERANCE: number; +}; + +const DEFAULT_CONFIG: CanvasLassoToolModuleConfig = { + PREVIEW_STROKE_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 1 }), + PREVIEW_FILL_COLOR: rgbaColorToString({ r: 90, g: 175, b: 255, a: 0.2 }), + PREVIEW_STROKE_WIDTH_PX: 1.5, + START_POINT_RADIUS_PX: 4, + START_POINT_STROKE_WIDTH_PX: 2, + START_POINT_HOVER_RADIUS_DELTA_PX: 2, + POLYGON_CLOSE_RADIUS_PX: 10, + MIN_FREEHAND_POINT_DISTANCE_PX: 1, + MAX_FREEHAND_SEGMENT_LENGTH_PX: 2, + FREEHAND_SIMPLIFY_MIN_POINTS: 200, + FREEHAND_SIMPLIFY_TOLERANCE: 0.6, +}; + +export class CanvasLassoToolModule extends CanvasModuleBase { + readonly type = 'lasso_tool'; + readonly id: string; + readonly path: string[]; + readonly parent: CanvasToolModule; + readonly manager: CanvasManager; + readonly log: Logger; + + config: CanvasLassoToolModuleConfig = DEFAULT_CONFIG; + + private freehandPoints: Coordinate[] = []; + private polygonPoints: Coordinate[] = []; + private polygonPointer: Coordinate | null = null; + private isDrawingFreehand = false; + + konva: { + group: Konva.Group; + fillShape: Konva.Line; + strokeShape: Konva.Line; + startPointIndicator: Konva.Circle; + }; + + constructor(parent: CanvasToolModule) { + super(); + this.id = getPrefixedId(this.type); + this.parent = parent; + this.manager = this.parent.manager; + this.path = this.manager.buildPath(this); + this.log = this.manager.buildLogger(this); + this.log.debug('Creating module'); + + this.konva = { + group: new Konva.Group({ name: `${this.type}:group`, listening: false }), + fillShape: new Konva.Line({ + name: `${this.type}:fill_shape`, + listening: false, + closed: true, + fill: this.config.PREVIEW_FILL_COLOR, + strokeEnabled: false, + visible: false, + perfectDrawEnabled: false, + }), + strokeShape: new Konva.Line({ + name: `${this.type}:stroke_shape`, + listening: false, + closed: false, + stroke: this.config.PREVIEW_STROKE_COLOR, + strokeWidth: this.config.PREVIEW_STROKE_WIDTH_PX, + lineCap: 'round', + lineJoin: 'round', + fillEnabled: false, + visible: false, + perfectDrawEnabled: false, + }), + startPointIndicator: new Konva.Circle({ + name: `${this.type}:start_point_indicator`, + listening: false, + fillEnabled: false, + stroke: this.config.PREVIEW_STROKE_COLOR, + visible: false, + perfectDrawEnabled: false, + }), + }; + + this.konva.group.add(this.konva.fillShape); + this.konva.group.add(this.konva.strokeShape); + this.konva.group.add(this.konva.startPointIndicator); + } + + syncCursorStyle = () => { + if (!this.parent.getCanDraw()) { + this.manager.stage.setCursor('not-allowed'); + return; + } + this.manager.stage.setCursor('crosshair'); + }; + + render = () => { + const tool = this.parent.$tool.get(); + const isTemporaryViewSwitch = tool === 'view' && this.parent.$toolBuffer.get() === 'lasso'; + if (tool !== 'lasso' && !isTemporaryViewSwitch) { + this.hidePreview(); + return; + } + + if (tool === 'lasso') { + this.syncCursorStyle(); + } + this.syncPreview(); + }; + + onToolChanged = () => { + const tool = this.parent.$tool.get(); + const isTemporaryViewSwitch = tool === 'view' && this.parent.$toolBuffer.get() === 'lasso'; + if (tool !== 'lasso' && !isTemporaryViewSwitch) { + this.reset(); + } + }; + + hasActiveSession = (): boolean => { + return this.isDrawingFreehand || this.freehandPoints.length > 0 || this.polygonPoints.length > 0; + }; + + onStagePointerDown = (e: KonvaEventObject) => { + const cursorPos = this.parent.$cursorPos.get(); + if (!cursorPos) { + return; + } + + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + const point = cursorPos.relative; + + // Keep middle click for pan and right click for context menu. + if (e.evt.button !== 0) { + return; + } + + if (lassoMode === 'freehand') { + if (!this.parent.$isPrimaryPointerDown.get()) { + return; + } + + this.polygonPoints = []; + this.polygonPointer = null; + this.freehandPoints = [point]; + this.isDrawingFreehand = true; + this.syncPreview(); + return; + } + + this.freehandPoints = []; + this.isDrawingFreehand = false; + + if (this.polygonPoints.length === 0) { + this.polygonPoints = [point]; + this.polygonPointer = point; + this.syncPreview(); + return; + } + + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return; + } + + if ( + this.polygonPoints.length >= 3 && + Math.hypot(point.x - startPoint.x, point.y - startPoint.y) <= this.getPolygonCloseRadius() + ) { + this.commitContour(this.polygonPoints); + this.reset(); + return; + } + + const snappedPoint = this.getPolygonPoint(point, e.evt.shiftKey); + this.polygonPoints = [...this.polygonPoints, snappedPoint]; + this.polygonPointer = snappedPoint; + this.syncPreview(); + }; + + onStagePointerMove = (_e: KonvaEventObject) => { + this.handlePointerMove(_e.evt.shiftKey); + }; + + onWindowPointerMove = (e: PointerEvent) => { + this.handlePointerMove(e.shiftKey); + }; + + onStagePointerUp = (_e: KonvaEventObject) => { + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + if (lassoMode !== 'freehand' || !this.isDrawingFreehand) { + return; + } + + this.commitContour(this.freehandPoints, true); + this.reset(); + }; + + onWindowPointerUp = () => { + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + if (lassoMode !== 'freehand' || !this.isDrawingFreehand) { + return; + } + + this.commitContour(this.freehandPoints, true); + this.reset(); + }; + + reset = () => { + this.freehandPoints = []; + this.polygonPoints = []; + this.polygonPointer = null; + this.isDrawingFreehand = false; + this.hidePreview(); + }; + + private handlePointerMove = (shouldSnap: boolean) => { + const cursorPos = this.parent.$cursorPos.get(); + if (!cursorPos) { + return; + } + + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + const point = cursorPos.relative; + + if (lassoMode === 'freehand') { + if (!this.isDrawingFreehand || !this.parent.$isPrimaryPointerDown.get()) { + return; + } + + const minDistance = this.manager.stage.unscale(this.config.MIN_FREEHAND_POINT_DISTANCE_PX); + const lastPoint = this.freehandPoints.at(-1) ?? null; + if (!isDistanceMoreThanMin(point, lastPoint, minDistance)) { + return; + } + this.appendFreehandPoint(point); + this.syncPreview(); + return; + } + + if (this.polygonPoints.length > 0) { + this.polygonPointer = this.getPolygonPoint(point, shouldSnap); + this.syncPreview(); + } + }; + + private appendFreehandPoint = (point: Coordinate) => { + const lastPoint = this.freehandPoints.at(-1) ?? null; + if (!lastPoint) { + this.freehandPoints.push(point); + return; + } + + const maxSegmentLength = this.manager.stage.unscale(this.config.MAX_FREEHAND_SEGMENT_LENGTH_PX); + const dx = point.x - lastPoint.x; + const dy = point.y - lastPoint.y; + const distance = Math.hypot(dx, dy); + + if (distance <= maxSegmentLength) { + this.freehandPoints.push(point); + return; + } + + const steps = Math.ceil(distance / maxSegmentLength); + for (let i = 1; i <= steps; i++) { + const t = i / steps; + this.freehandPoints.push({ + x: lastPoint.x + dx * t, + y: lastPoint.y + dy * t, + }); + } + }; + + private hidePreview = () => { + this.konva.strokeShape.visible(false); + this.konva.fillShape.visible(false); + this.konva.startPointIndicator.visible(false); + }; + + private syncPreview = () => { + const lassoMode = this.manager.stateApi.getSettings().lassoMode; + const stageScale = this.manager.stage.getScale(); + const strokeWidth = this.config.PREVIEW_STROKE_WIDTH_PX / stageScale; + + let points: Coordinate[] = []; + if (lassoMode === 'freehand') { + points = this.freehandPoints; + } else { + points = [...this.polygonPoints]; + if (this.polygonPointer) { + points.push(this.polygonPointer); + } + } + + if (points.length < 1) { + this.hidePreview(); + return; + } + + const flat = points.flatMap((point) => [point.x, point.y]); + this.konva.strokeShape.setAttrs({ + points: flat, + strokeWidth, + visible: true, + }); + + if (points.length >= 3) { + this.konva.fillShape.setAttrs({ + points: flat, + visible: true, + }); + } else { + this.konva.fillShape.visible(false); + } + + if (lassoMode === 'polygon' && this.polygonPoints.length > 0) { + const startPoint = this.polygonPoints[0]; + if (startPoint) { + const isHoveringStartPoint = this.getIsHoveringStartPoint(startPoint); + const baseRadius = this.manager.stage.unscale(this.config.START_POINT_RADIUS_PX); + this.konva.startPointIndicator.setAttrs({ + x: startPoint.x, + y: startPoint.y, + radius: + baseRadius + + (isHoveringStartPoint ? this.manager.stage.unscale(this.config.START_POINT_HOVER_RADIUS_DELTA_PX) : 0), + strokeWidth: this.manager.stage.unscale(this.config.START_POINT_STROKE_WIDTH_PX), + visible: true, + }); + } + } else { + this.konva.startPointIndicator.visible(false); + } + }; + + private getPolygonCloseRadius = (): number => { + return this.manager.stage.unscale(this.config.POLYGON_CLOSE_RADIUS_PX); + }; + + private getIsHoveringStartPoint = (startPoint: Coordinate): boolean => { + if (this.polygonPoints.length < 3) { + return false; + } + + const pointerPoint = this.parent.$cursorPos.get()?.relative; + if (!pointerPoint) { + return false; + } + + return Math.hypot(pointerPoint.x - startPoint.x, pointerPoint.y - startPoint.y) <= this.getPolygonCloseRadius(); + }; + + private getPolygonPoint = (point: Coordinate, shouldSnap: boolean): Coordinate => { + if (!shouldSnap) { + return point; + } + + const lastPoint = this.polygonPoints.at(-1); + if (!lastPoint) { + return point; + } + + const dx = point.x - lastPoint.x; + const dy = point.y - lastPoint.y; + const distance = Math.hypot(dx, dy); + + if (distance === 0) { + return point; + } + + const SNAP_ANGLE = Math.PI / 4; + const angle = Math.atan2(dy, dx); + const snappedAngle = Math.round(angle / SNAP_ANGLE) * SNAP_ANGLE; + + const snappedPoint = { + x: lastPoint.x + Math.cos(snappedAngle) * distance, + y: lastPoint.y + Math.sin(snappedAngle) * distance, + }; + + return this.alignPointToStart(snappedPoint); + }; + + private alignPointToStart = (point: Coordinate): Coordinate => { + if (this.polygonPoints.length < 2) { + return point; + } + + const startPoint = this.polygonPoints[0]; + if (!startPoint) { + return point; + } + + const alignThreshold = this.getPolygonCloseRadius(); + const deltaX = Math.abs(point.x - startPoint.x); + const deltaY = Math.abs(point.y - startPoint.y); + const canAlignX = deltaX <= alignThreshold; + const canAlignY = deltaY <= alignThreshold; + + if (!canAlignX && !canAlignY) { + return point; + } + + if (canAlignX && canAlignY) { + if (deltaX <= deltaY) { + return { x: startPoint.x, y: point.y }; + } + return { x: point.x, y: startPoint.y }; + } + + if (canAlignX) { + return { x: startPoint.x, y: point.y }; + } + + return { x: point.x, y: startPoint.y }; + }; + + private closeContour = (points: Coordinate[]): Coordinate[] => { + if (points.length === 0) { + return []; + } + + const start = points[0]; + const end = points.at(-1); + if (!start || !end) { + return points; + } + + if (start.x === end.x && start.y === end.y) { + return points; + } + + return [...points, start]; + }; + + private commitContour = (points: Coordinate[], simplifyFreehand: boolean = false) => { + const contourPoints = simplifyFreehand ? this.simplifyFreehandContour(points) : points; + if (contourPoints.length < 3) { + return; + } + + const closedPoints = this.closeContour(contourPoints); + if (closedPoints.length < 4) { + return; + } + + let targetMaskId = this.getActiveInpaintMaskId(); + if (!targetMaskId) { + this.manager.stateApi.addInpaintMask({ isSelected: true }); + targetMaskId = this.getActiveInpaintMaskId(); + } + + if (!targetMaskId) { + return; + } + + const targetMaskState = this.manager.stateApi + .getInpaintMasksState() + .entities.find((entity) => entity.id === targetMaskId); + if (!targetMaskState) { + return; + } + + const normalizedPoints = closedPoints.flatMap((point) => { + const normalizedPoint = offsetCoord(point, targetMaskState.position); + return [normalizedPoint.x, normalizedPoint.y]; + }); + + this.manager.stateApi.addLasso({ + entityIdentifier: { type: 'inpaint_mask', id: targetMaskId }, + lasso: { + id: getPrefixedId('lasso'), + type: 'lasso', + points: normalizedPoints, + compositeOperation: + this.manager.stateApi.$ctrlKey.get() || this.manager.stateApi.$metaKey.get() + ? 'destination-out' + : 'source-over', + }, + }); + }; + + private simplifyFreehandContour = (points: Coordinate[]): Coordinate[] => { + if (points.length < this.config.FREEHAND_SIMPLIFY_MIN_POINTS) { + return points; + } + + const flatPoints = points.flatMap((point) => [point.x, point.y]); + const simplifiedFlatPoints = simplifyFlatNumbersArray(flatPoints, { + tolerance: this.config.FREEHAND_SIMPLIFY_TOLERANCE, + highestQuality: true, + }); + if (simplifiedFlatPoints.length < 6) { + return points; + } + + const simplifiedPoints = this.flatNumbersToCoords(simplifiedFlatPoints); + if (simplifiedPoints.length < 3) { + return points; + } + + return simplifiedPoints; + }; + + private flatNumbersToCoords = (points: number[]): Coordinate[] => { + const coords: Coordinate[] = []; + for (let i = 0; i < points.length; i += 2) { + const x = points[i]; + const y = points[i + 1]; + if (x === undefined || y === undefined) { + continue; + } + coords.push({ x, y }); + } + return coords; + }; + + private getActiveInpaintMaskId = (): string | null => { + const canvasState = this.manager.stateApi.getCanvasState(); + const selectedEntityIdentifier = canvasState.selectedEntityIdentifier; + if (selectedEntityIdentifier?.type === 'inpaint_mask') { + const selectedMask = canvasState.inpaintMasks.entities.find( + (entity) => entity.id === selectedEntityIdentifier.id + ); + if (selectedMask?.isEnabled) { + return selectedMask.id; + } + // If the selected mask is disabled, commit to a new mask instead. + return null; + } + + const inpaintMasks = canvasState.inpaintMasks.entities; + const activeMask = [...inpaintMasks].reverse().find((entity) => entity.isEnabled); + return activeMask?.id ?? null; + }; + + repr = () => { + return { + id: this.id, + type: this.type, + path: this.path, + freehandPoints: this.freehandPoints, + polygonPoints: this.polygonPoints, + polygonPointer: this.polygonPointer, + isDrawingFreehand: this.isDrawingFreehand, + }; + }; +} diff --git a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts index c25a714bad..668ac7be3b 100644 --- a/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts +++ b/invokeai/frontend/web/src/features/controlLayers/konva/CanvasTool/CanvasToolModule.ts @@ -5,6 +5,7 @@ import { CanvasBrushToolModule } from 'features/controlLayers/konva/CanvasTool/C import { CanvasColorPickerToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasColorPickerToolModule'; import { CanvasEraserToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasEraserToolModule'; import { CanvasGradientToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasGradientToolModule'; +import { CanvasLassoToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasLassoToolModule'; import { CanvasMoveToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasMoveToolModule'; import { CanvasRectToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasRectToolModule'; import { CanvasTextToolModule } from 'features/controlLayers/konva/CanvasTool/CanvasTextToolModule'; @@ -38,6 +39,7 @@ Konva.dragButtons = [0]; const KEY_ESCAPE = 'Escape'; const KEY_SPACE = ' '; const KEY_ALT = 'Alt'; +const CODE_SPACE = 'Space'; type CanvasToolModuleConfig = { BRUSH_SPACING_TARGET_SCALE: number; @@ -62,6 +64,7 @@ export class CanvasToolModule extends CanvasModuleBase { brush: CanvasBrushToolModule; eraser: CanvasEraserToolModule; rect: CanvasRectToolModule; + lasso: CanvasLassoToolModule; gradient: CanvasGradientToolModule; colorPicker: CanvasColorPickerToolModule; bbox: CanvasBboxToolModule; @@ -121,6 +124,7 @@ export class CanvasToolModule extends CanvasModuleBase { brush: new CanvasBrushToolModule(this), eraser: new CanvasEraserToolModule(this), rect: new CanvasRectToolModule(this), + lasso: new CanvasLassoToolModule(this), gradient: new CanvasGradientToolModule(this), colorPicker: new CanvasColorPickerToolModule(this), bbox: new CanvasBboxToolModule(this), @@ -139,15 +143,26 @@ export class CanvasToolModule extends CanvasModuleBase { this.konva.group.add(this.tools.colorPicker.konva.group); this.konva.group.add(this.tools.text.konva.group); this.konva.group.add(this.tools.bbox.konva.group); + this.konva.group.add(this.tools.lasso.konva.group); this.subscriptions.add(this.manager.stage.$stageAttrs.listen(this.render)); this.subscriptions.add(this.manager.$isBusy.listen(this.render)); this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSettingsSlice, this.render)); this.subscriptions.add(this.manager.stateApi.createStoreSubscription(selectCanvasSlice, this.render)); this.subscriptions.add( - this.$tool.listen(() => { - // On tool switch, reset mouse state - this.manager.tool.$isPrimaryPointerDown.set(false); + this.$tool.listen((tool, previousTool) => { + // Preserve pointer state during temporary view switching so lasso sessions can freeze/resume on space. + const shouldPreservePointerState = + this.$toolBuffer.get() === 'lasso' && + this.tools.lasso.hasActiveSession() && + ((previousTool === 'lasso' && tool === 'view') || (previousTool === 'view' && tool === 'lasso')); + + if (!shouldPreservePointerState) { + // On tool switch, reset mouse state + this.manager.tool.$isPrimaryPointerDown.set(false); + } + + this.tools.lasso.onToolChanged(); void this.tools.text.onToolChanged(); this.render(); }) @@ -189,6 +204,8 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools.colorPicker.syncCursorStyle(); } else if (tool === 'text') { this.tools.text.syncCursorStyle(); + } else if (tool === 'lasso') { + this.tools.lasso.syncCursorStyle(); } else if (selectedEntityAdapter) { if (selectedEntityAdapter.$isDisabled.get()) { stage.setCursor('not-allowed'); @@ -222,6 +239,7 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools.colorPicker.render(); this.tools.text.render(); this.tools.bbox.render(); + this.tools.lasso.render(); }; syncCursorPositions = () => { @@ -235,6 +253,19 @@ export class CanvasToolModule extends CanvasModuleBase { this.$cursorPos.set({ relative, absolute }); }; + syncCursorPositionsFromWindowEvent = (e: PointerEvent): boolean => { + this.konva.stage.setPointersPositions(e); + const relative = this.konva.stage.getRelativePointerPosition(); + const absolute = this.konva.stage.getPointerPosition(); + + if (!relative || !absolute) { + return false; + } + + this.$cursorPos.set({ relative, absolute }); + return true; + }; + getClip = ( entity: CanvasRegionalGuidanceState | CanvasControlLayerState | CanvasRasterLayerState | CanvasInpaintMaskState ) => { @@ -274,6 +305,7 @@ export class CanvasToolModule extends CanvasModuleBase { window.addEventListener('keydown', this.onKeyDown); window.addEventListener('keyup', this.onKeyUp); + window.addEventListener('pointermove', this.onWindowPointerMove); window.addEventListener('pointerup', this.onWindowPointerUp); window.addEventListener('blur', this.onWindowBlur); @@ -289,6 +321,7 @@ export class CanvasToolModule extends CanvasModuleBase { window.removeEventListener('keydown', this.onKeyDown); window.removeEventListener('keyup', this.onKeyUp); + window.removeEventListener('pointermove', this.onWindowPointerMove); window.removeEventListener('pointerup', this.onWindowPointerUp); window.removeEventListener('blur', this.onWindowBlur); }; @@ -316,6 +349,18 @@ export class CanvasToolModule extends CanvasModuleBase { return true; } + if (tool === 'lasso') { + if (this.manager.$isBusy.get()) { + return false; + } + + if (this.manager.stage.getIsDragging()) { + return false; + } + + return true; + } + if (this.manager.stateApi.getRenderedEntityCount() === 0) { return false; } @@ -407,6 +452,8 @@ export class CanvasToolModule extends CanvasModuleBase { await this.tools.eraser.onStagePointerDown(e); } else if (tool === 'rect') { await this.tools.rect.onStagePointerDown(e); + } else if (tool === 'lasso') { + await this.tools.lasso.onStagePointerDown(e); } else if (tool === 'gradient') { await this.tools.gradient.onStagePointerDown(e); } else if (tool === 'text') { @@ -441,6 +488,8 @@ export class CanvasToolModule extends CanvasModuleBase { this.tools.eraser.onStagePointerUp(e); } else if (tool === 'rect') { this.tools.rect.onStagePointerUp(e); + } else if (tool === 'lasso') { + void this.tools.lasso.onStagePointerUp(e); } else if (tool === 'gradient') { this.tools.gradient.onStagePointerUp(e); } @@ -476,6 +525,8 @@ export class CanvasToolModule extends CanvasModuleBase { await this.tools.eraser.onStagePointerMove(e); } else if (tool === 'rect') { await this.tools.rect.onStagePointerMove(e); + } else if (tool === 'lasso') { + await this.tools.lasso.onStagePointerMove(e); } else if (tool === 'gradient') { await this.tools.gradient.onStagePointerMove(e); } else if (tool === 'text') { @@ -560,6 +611,7 @@ export class CanvasToolModule extends CanvasModuleBase { onWindowPointerUp = (_: PointerEvent) => { try { this.$isPrimaryPointerDown.set(false); + void this.tools.lasso.onWindowPointerUp(); const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); if (selectedEntity && selectedEntity.bufferRenderer.hasBuffer() && !this.manager.$isBusy.get()) { @@ -570,6 +622,41 @@ export class CanvasToolModule extends CanvasModuleBase { } }; + onWindowPointerMove = (e: PointerEvent) => { + const target = e.target; + if (target instanceof Node && this.manager.stage.container.contains(target)) { + return; + } + + if (this.$tool.get() !== 'lasso') { + return; + } + + if (!this.getCanDraw()) { + return; + } + + if (!this.$isPrimaryPointerDown.get()) { + return; + } + + if (!this.tools.lasso.hasActiveSession()) { + return; + } + + try { + this.$lastPointerType.set(e.pointerType); + + if (!this.syncCursorPositionsFromWindowEvent(e)) { + return; + } + + this.tools.lasso.onWindowPointerMove(e); + } finally { + this.render(); + } + }; + /** * We want to reset any "quick-switch" tool selection on window blur. Fixes an issue where you alt-tab out of the app * and the color picker tool is still active when you come back. @@ -579,6 +666,7 @@ export class CanvasToolModule extends CanvasModuleBase { }; onKeyDown = (e: KeyboardEvent) => { + const isSpaceKey = e.key === KEY_SPACE || e.code === CODE_SPACE; if (e.target instanceof HTMLInputElement || e.target instanceof HTMLTextAreaElement) { return; } @@ -600,6 +688,9 @@ export class CanvasToolModule extends CanvasModuleBase { if (e.key === KEY_ESCAPE) { // Cancel shape drawing on escape e.preventDefault(); + if (this.$tool.get() === 'lasso') { + this.tools.lasso.reset(); + } const selectedEntity = this.manager.stateApi.getSelectedEntityAdapter(); if ( selectedEntity && @@ -612,19 +703,27 @@ export class CanvasToolModule extends CanvasModuleBase { return; } - if (e.key === KEY_SPACE) { + if (isSpaceKey) { // Select the view tool on space key down e.preventDefault(); - this.$toolBuffer.set(this.$tool.get()); - this.$tool.set('view'); + e.stopPropagation(); + const currentTool = this.$tool.get(); + this.$toolBuffer.set(currentTool); this.manager.stateApi.$spaceKey.set(true); - this.$cursorPos.set(null); + this.$tool.set('view'); + if (currentTool === 'lasso' && this.tools.lasso.hasActiveSession() && this.$isPrimaryPointerDown.get()) { + // Start panning immediately if user is already drawing with freehand lasso. + this.manager.stage.startDragging(); + } else { + this.$cursorPos.set(null); + } return; } if (e.key === KEY_ALT) { // Select the color picker on alt key down e.preventDefault(); + e.stopPropagation(); this.$toolBuffer.set(this.$tool.get()); this.$tool.set('colorPicker'); } @@ -644,9 +743,10 @@ export class CanvasToolModule extends CanvasModuleBase { return; } - if (e.key === KEY_SPACE) { + if (e.key === KEY_SPACE || e.code === CODE_SPACE) { // Revert the tool to the previous tool on space key up e.preventDefault(); + e.stopPropagation(); this.revertToolBuffer(); this.manager.stateApi.$spaceKey.set(false); return; @@ -655,6 +755,7 @@ export class CanvasToolModule extends CanvasModuleBase { if (e.key === KEY_ALT) { // Revert the tool to the previous tool on alt key up e.preventDefault(); + e.stopPropagation(); this.revertToolBuffer(); return; } @@ -684,6 +785,7 @@ export class CanvasToolModule extends CanvasModuleBase { eraser: this.tools.eraser.repr(), colorPicker: this.tools.colorPicker.repr(), rect: this.tools.rect.repr(), + lasso: this.tools.lasso.repr(), gradient: this.tools.gradient.repr(), bbox: this.tools.bbox.repr(), view: this.tools.view.repr(), diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts index 91428b4521..202b70e142 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSettingsSlice.ts @@ -13,6 +13,7 @@ const zTransformSmoothingMode = z.enum(['bilinear', 'bicubic', 'hamming', 'lancz export type TransformSmoothingMode = z.infer; const zGradientType = z.enum(['linear', 'radial']); +const zLassoMode = z.enum(['freehand', 'polygon']); const zCanvasSettingsState = z.object({ /** @@ -118,6 +119,10 @@ const zCanvasSettingsState = z.object({ * Whether the gradient tool clips to the drag gesture. */ gradientClipEnabled: z.boolean().default(true), + /** + * The lasso tool mode. + */ + lassoMode: zLassoMode.default('freehand'), }); type CanvasSettingsState = z.infer; @@ -148,6 +153,7 @@ const getInitialState = (): CanvasSettingsState => ({ transformSmoothingMode: 'bicubic', gradientType: 'linear', gradientClipEnabled: true, + lassoMode: 'freehand', }); const slice = createSlice({ @@ -245,6 +251,9 @@ const slice = createSlice({ settingsGradientClipToggled: (state) => { state.gradientClipEnabled = !state.gradientClipEnabled; }, + settingsLassoModeChanged: (state, action: PayloadAction) => { + state.lassoMode = action.payload; + }, }, }); @@ -276,6 +285,7 @@ export const { settingsFillColorPickerPinnedSet, settingsGradientTypeChanged, settingsGradientClipToggled, + settingsLassoModeChanged, } = slice.actions; export const canvasSettingsSliceConfig: SliceConfig = { @@ -317,3 +327,4 @@ export const selectTransformSmoothingEnabled = createCanvasSettingsSelector( export const selectTransformSmoothingMode = createCanvasSettingsSelector((settings) => settings.transformSmoothingMode); export const selectGradientType = createCanvasSettingsSelector((settings) => settings.gradientType); export const selectGradientClipEnabled = createCanvasSettingsSelector((settings) => settings.gradientClipEnabled); +export const selectLassoMode = createCanvasSettingsSelector((settings) => settings.lassoMode); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index 79d3963d12..d114568c3c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -66,6 +66,7 @@ import type { EntityEraserLineAddedPayload, EntityGradientAddedPayload, EntityIdentifierPayload, + EntityLassoAddedPayload, EntityMovedToPayload, EntityRasterizedPayload, EntityRectAddedPayload, @@ -99,6 +100,12 @@ import { makeDefaultRasterLayerAdjustments, } from './util'; +const resetInpaintMasksHiddenIfEmpty = (state: CanvasState) => { + if (state.inpaintMasks.entities.length === 0) { + state.inpaintMasks.isHidden = false; + } +}; + const slice = createSlice({ name: 'canvas', initialState: getInitialCanvasState(), @@ -1061,6 +1068,7 @@ const slice = createSlice({ (entity) => !mergedEntitiesToDelete.includes(entity.id) ); } + resetInpaintMasksHiddenIfEmpty(state); const entityIdentifier = getEntityIdentifier(entityState); if (isSelected || mergedEntitiesToDelete.length > 0) { @@ -1132,6 +1140,7 @@ const slice = createSlice({ if (replace) { // Remove the inpaint mask state.inpaintMasks.entities = state.inpaintMasks.entities.filter((layer) => layer.id !== entityIdentifier.id); + resetInpaintMasksHiddenIfEmpty(state); } // Add the new regional guidance @@ -1548,6 +1557,17 @@ const slice = createSlice({ // re-render it (reference equality check). I don't like this behaviour. entity.objects.push({ ...rect }); }, + entityLassoAdded: (state, action: PayloadAction) => { + const { entityIdentifier, lasso } = action.payload; + const entity = selectEntity(state, entityIdentifier); + if (!entity) { + return; + } + + // TODO(psyche): If we add the object without splatting, the renderer will see it as the same object and not + // re-render it (reference equality check). I don't like this behaviour. + entity.objects.push({ ...lasso }); + }, entityGradientAdded: (state, action: PayloadAction) => { const { entityIdentifier, gradient } = action.payload; const entity = selectEntity(state, entityIdentifier); @@ -1590,6 +1610,7 @@ const slice = createSlice({ break; } + resetInpaintMasksHiddenIfEmpty(state); state.selectedEntityIdentifier = selectedEntityIdentifier; }, entityArrangedForwardOne: (state, action: PayloadAction) => { @@ -1678,6 +1699,7 @@ const slice = createSlice({ break; case 'inpaint_mask': state.inpaintMasks.isHidden = !state.inpaintMasks.isHidden; + resetInpaintMasksHiddenIfEmpty(state); break; case 'regional_guidance': state.regionalGuidance.isHidden = !state.regionalGuidance.isHidden; @@ -1686,13 +1708,16 @@ const slice = createSlice({ }, allNonRasterLayersIsHiddenToggled: (state) => { const hasVisibleNonRasterLayers = - !state.controlLayers.isHidden || !state.inpaintMasks.isHidden || !state.regionalGuidance.isHidden; + (state.controlLayers.entities.length > 0 && !state.controlLayers.isHidden) || + (state.inpaintMasks.entities.length > 0 && !state.inpaintMasks.isHidden) || + (state.regionalGuidance.entities.length > 0 && !state.regionalGuidance.isHidden); const shouldHide = hasVisibleNonRasterLayers; state.controlLayers.isHidden = shouldHide; state.inpaintMasks.isHidden = shouldHide; state.regionalGuidance.isHidden = shouldHide; + resetInpaintMasksHiddenIfEmpty(state); }, allEntitiesDeleted: (state) => { // Deleting all entities is equivalent to resetting the state for each entity type @@ -1708,6 +1733,7 @@ const slice = createSlice({ state.inpaintMasks.entities = inpaintMasks; state.rasterLayers.entities = rasterLayers; state.regionalGuidance.entities = regionalGuidance; + resetInpaintMasksHiddenIfEmpty(state); return state; }, canvasUndo: () => {}, @@ -1787,6 +1813,7 @@ export const { entityBrushLineAdded, entityEraserLineAdded, entityRectAdded, + entityLassoAdded, entityGradientAdded, // Raster layer adjustments rasterLayerAdjustmentsSet, @@ -1913,7 +1940,13 @@ export const canvasSliceConfig: SliceConfig = { }, }; -const doNotGroupMatcher = isAnyOf(entityBrushLineAdded, entityEraserLineAdded, entityRectAdded, entityGradientAdded); +const doNotGroupMatcher = isAnyOf( + entityBrushLineAdded, + entityEraserLineAdded, + entityRectAdded, + entityLassoAdded, + entityGradientAdded +); // Store rapid actions of the same type at most once every x time. // See: https://github.com/omnidan/redux-undo/blob/master/examples/throttled-drag/util/undoFilter.js diff --git a/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts b/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts index 5c0abfdb89..2e2ae09212 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/selectors.ts @@ -361,5 +361,9 @@ export const selectCanvasMetadata = createSelector( * This is used to determine the state of the toggle button that shows/hides all non-raster layers. */ export const selectNonRasterLayersIsHidden = createSelector(selectCanvasSlice, (canvas) => { - return canvas.controlLayers.isHidden && canvas.inpaintMasks.isHidden && canvas.regionalGuidance.isHidden; + const areControlLayersEffectivelyHidden = canvas.controlLayers.entities.length === 0 || canvas.controlLayers.isHidden; + const areInpaintMasksEffectivelyHidden = canvas.inpaintMasks.entities.length === 0 || canvas.inpaintMasks.isHidden; + const areRegionalGuidanceEffectivelyHidden = + canvas.regionalGuidance.entities.length === 0 || canvas.regionalGuidance.isHidden; + return areControlLayersEffectivelyHidden && areInpaintMasksEffectivelyHidden && areRegionalGuidanceEffectivelyHidden; }); diff --git a/invokeai/frontend/web/src/features/controlLayers/store/types.ts b/invokeai/frontend/web/src/features/controlLayers/store/types.ts index 7247f4cf86..c16ffdbeab 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/types.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/types.ts @@ -105,7 +105,7 @@ const zIPMethodV2 = z.enum(['full', 'style', 'composition', 'style_strong', 'sty export type IPMethodV2 = z.infer; export const isIPMethodV2 = (v: unknown): v is IPMethodV2 => zIPMethodV2.safeParse(v).success; -const _zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'gradient', 'view', 'bbox', 'colorPicker', 'text']); +const _zTool = z.enum(['brush', 'eraser', 'move', 'rect', 'lasso', 'gradient', 'view', 'bbox', 'colorPicker', 'text']); export type Tool = z.infer; const zPoints = z.array(z.number()).refine((points) => points.length % 2 === 0, { @@ -260,6 +260,20 @@ const zCanvasRectState = z.object({ }); export type CanvasRectState = z.infer; +const zCanvasLassoCompositeOperation = z.enum(['source-over', 'destination-out']); + +const zCanvasLassoState = z.object({ + id: zId, + type: z.literal('lasso'), + /** + * Points in the format [x1, y1, x2, y2, ...]. + * The lasso tool always commits a closed contour. + */ + points: zPoints, + compositeOperation: zCanvasLassoCompositeOperation.default('source-over'), +}); +export type CanvasLassoState = z.infer; + // Gradient state includes clip metadata so the tool can optionally clip to drag gesture. const zCanvasLinearGradientState = z.object({ id: zId, @@ -309,6 +323,7 @@ const zCanvasObjectState = z.union([ zCanvasBrushLineState, zCanvasEraserLineState, zCanvasRectState, + zCanvasLassoState, zCanvasBrushLineWithPressureState, zCanvasEraserLineWithPressureState, zCanvasGradientState, @@ -955,6 +970,7 @@ export type EntityEraserLineAddedPayload = EntityIdentifierPayload<{ eraserLine: CanvasEraserLineState | CanvasEraserLineWithPressureState; }>; export type EntityRectAddedPayload = EntityIdentifierPayload<{ rect: CanvasRectState }>; +export type EntityLassoAddedPayload = EntityIdentifierPayload<{ lasso: CanvasLassoState }>; export type EntityGradientAddedPayload = EntityIdentifierPayload<{ gradient: CanvasGradientState }>; export type EntityRasterizedPayload = EntityIdentifierPayload<{ imageObject: CanvasImageState; diff --git a/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts b/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts index c6a0bd6705..dfc9b5d280 100644 --- a/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts +++ b/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts @@ -118,6 +118,7 @@ export const useHotkeyData = (): HotkeysData => { addHotkey('canvas', 'selectEraserTool', ['e']); addHotkey('canvas', 'selectMoveTool', ['v']); addHotkey('canvas', 'selectRectTool', ['u']); + addHotkey('canvas', 'selectLassoTool', ['l']); addHotkey('canvas', 'selectViewTool', ['h']); addHotkey('canvas', 'selectColorPickerTool', ['i']); addHotkey('canvas', 'setFillColorsToDefault', ['d']); From 06a1881bbda152aece862750937d08c5c28cc07f Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 14 Apr 2026 02:38:47 +0200 Subject: [PATCH 4/6] feat(ui): group nodes by category in add-node dialog (#8912) * feat(ui): group nodes by category in add-node dialog Add collapsible category grouping to the node picker command palette. Categories are parsed from the backend schema and displayed as expandable sections with caret icons. All categories auto-expand when searching. * feat(ui): add toggle for category grouping in add-node dialog and prioritize exact matches Add a persistent "Group Nodes by Category" setting to workflow editor settings, allowing users to switch between grouped and flat node list views. Also sort exact title matches to the top when searching. * fix: update test schema categories to match expected templates * feat: add expand/collapse all buttons to node picker and fix node categories Add "Expand All" and "Collapse All" link-buttons above the grouped category list in the add-node dialog so users can quickly open or close all categories at once. Buttons are hidden during search since categories auto-expand while searching. Fix two miscategorized nodes: Z-Image ControlNet was in "Control" instead of "Controlnet", and Upscale (RealESRGAN) was in "Esrgan" instead of "Upscale". * refactor(nodes): clean up node category taxonomy Reorganize all built-in invocation categories into a consistent set of 18 groups (model, prompt, conditioning, controlnet_preprocessors, latents, image, mask, inpaint, tiles, upscale, segmentation, math, strings, primitives, batch, metadata, multimodal, canvas). - Move denoise/i2l/l2i nodes consistently into "latents" - Move all mask creation/manipulation nodes into "mask" - Split ControlNet preprocessors out of "controlnet" into their own group - Fold "unet", "vllm", "string", "ip_adapter", "t2i_adapter" into larger groups - Move metadata_linked denoise wrappers from "latents" to "metadata" - Add missing category to ideal_size - Introduce dedicated "canvas" group for canvas/output/panel nodes Also adds the now-required `category` field to invocation template fixtures in validateConnection.test.ts. * Chore Ruff Format --------- Co-authored-by: dunkeroni --- invokeai/app/invocations/batch.py | 16 +- invokeai/app/invocations/canny.py | 2 +- invokeai/app/invocations/cogview4_denoise.py | 2 +- .../invocations/cogview4_image_to_latents.py | 2 +- .../app/invocations/cogview4_text_encoder.py | 2 +- invokeai/app/invocations/collections.py | 8 +- invokeai/app/invocations/color_map.py | 2 +- invokeai/app/invocations/compel.py | 8 +- invokeai/app/invocations/content_shuffle.py | 2 +- invokeai/app/invocations/controlnet.py | 4 +- .../app/invocations/create_denoise_mask.py | 2 +- .../app/invocations/create_gradient_mask.py | 2 +- invokeai/app/invocations/depth_anything.py | 2 +- invokeai/app/invocations/dw_openpose.py | 2 +- invokeai/app/invocations/facetools.py | 14 +- invokeai/app/invocations/flux2_denoise.py | 2 +- .../invocations/flux2_klein_text_encoder.py | 2 +- invokeai/app/invocations/flux_controlnet.py | 2 +- invokeai/app/invocations/flux_denoise.py | 2 +- invokeai/app/invocations/flux_fill.py | 2 +- invokeai/app/invocations/flux_ip_adapter.py | 2 +- invokeai/app/invocations/flux_redux.py | 2 +- invokeai/app/invocations/flux_text_encoder.py | 2 +- invokeai/app/invocations/grounding_dino.py | 2 +- invokeai/app/invocations/hed.py | 2 +- invokeai/app/invocations/ideal_size.py | 1 + invokeai/app/invocations/image.py | 20 +- invokeai/app/invocations/image_panels.py | 2 +- invokeai/app/invocations/ip_adapter.py | 2 +- invokeai/app/invocations/lineart.py | 2 +- invokeai/app/invocations/lineart_anime.py | 2 +- .../app/invocations/llava_onevision_vllm.py | 2 +- invokeai/app/invocations/logic.py | 2 +- invokeai/app/invocations/mask.py | 8 +- invokeai/app/invocations/mediapipe_face.py | 2 +- invokeai/app/invocations/metadata_linked.py | 6 +- invokeai/app/invocations/mlsd.py | 2 +- invokeai/app/invocations/model.py | 2 +- invokeai/app/invocations/normal_bae.py | 2 +- invokeai/app/invocations/pbr_maps.py | 4 +- invokeai/app/invocations/pidi.py | 2 +- invokeai/app/invocations/sd3_denoise.py | 2 +- .../app/invocations/sd3_image_to_latents.py | 2 +- invokeai/app/invocations/sd3_text_encoder.py | 2 +- invokeai/app/invocations/strings.py | 12 +- invokeai/app/invocations/t2i_adapter.py | 2 +- invokeai/app/invocations/upscale.py | 2 +- invokeai/app/invocations/z_image_control.py | 2 +- invokeai/app/invocations/z_image_denoise.py | 2 +- .../invocations/z_image_image_to_latents.py | 2 +- .../z_image_seed_variance_enhancer.py | 2 +- .../app/invocations/z_image_text_encoder.py | 2 +- invokeai/frontend/web/public/locales/en.json | 4 + .../flow/AddNodeCmdk/AddNodeCmdk.tsx | 377 +++++++++++++----- .../TopRightPanel/WorkflowEditorSettings.tsx | 18 + .../features/nodes/store/util/testUtils.ts | 10 +- .../store/util/validateConnection.test.ts | 3 + .../nodes/store/workflowSettingsSlice.ts | 10 + .../src/features/nodes/types/invocation.ts | 1 + .../features/nodes/util/schema/parseSchema.ts | 2 + 60 files changed, 424 insertions(+), 184 deletions(-) diff --git a/invokeai/app/invocations/batch.py b/invokeai/app/invocations/batch.py index 34ecd38f26..f79b8816ad 100644 --- a/invokeai/app/invocations/batch.py +++ b/invokeai/app/invocations/batch.py @@ -56,7 +56,7 @@ class BaseBatchInvocation(BaseInvocation): "image_batch", title="Image Batch", tags=["primitives", "image", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -87,7 +87,7 @@ class ImageGeneratorField(BaseModel): "image_generator", title="Image Generator", tags=["primitives", "board", "image", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -111,7 +111,7 @@ class ImageGenerator(BaseInvocation): "string_batch", title="String Batch", tags=["primitives", "string", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -142,7 +142,7 @@ class StringGeneratorField(BaseModel): "string_generator", title="String Generator", tags=["primitives", "string", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -166,7 +166,7 @@ class StringGenerator(BaseInvocation): "integer_batch", title="Integer Batch", tags=["primitives", "integer", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -195,7 +195,7 @@ class IntegerGeneratorField(BaseModel): "integer_generator", title="Integer Generator", tags=["primitives", "int", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -219,7 +219,7 @@ class IntegerGenerator(BaseInvocation): "float_batch", title="Float Batch", tags=["primitives", "float", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) @@ -250,7 +250,7 @@ class FloatGeneratorField(BaseModel): "float_generator", title="Float Generator", tags=["primitives", "float", "number", "batch", "special"], - category="primitives", + category="batch", version="1.0.0", classification=Classification.Special, ) diff --git a/invokeai/app/invocations/canny.py b/invokeai/app/invocations/canny.py index 0cdc386e62..dbfde6d353 100644 --- a/invokeai/app/invocations/canny.py +++ b/invokeai/app/invocations/canny.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import cv2_to_pil, pil_to_cv2 "canny_edge_detection", title="Canny Edge Detection", tags=["controlnet", "canny"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class CannyEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/cogview4_denoise.py b/invokeai/app/invocations/cogview4_denoise.py index 070d8a3478..e8b910f731 100644 --- a/invokeai/app/invocations/cogview4_denoise.py +++ b/invokeai/app/invocations/cogview4_denoise.py @@ -33,7 +33,7 @@ from invokeai.backend.util.devices import TorchDevice "cogview4_denoise", title="Denoise - CogView4", tags=["image", "cogview4"], - category="image", + category="latents", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/cogview4_image_to_latents.py b/invokeai/app/invocations/cogview4_image_to_latents.py index 630b9ab1e3..facbc38dd4 100644 --- a/invokeai/app/invocations/cogview4_image_to_latents.py +++ b/invokeai/app/invocations/cogview4_image_to_latents.py @@ -27,7 +27,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory "cogview4_i2l", title="Image to Latents - CogView4", tags=["image", "latents", "vae", "i2l", "cogview4"], - category="image", + category="latents", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/cogview4_text_encoder.py b/invokeai/app/invocations/cogview4_text_encoder.py index 3b5b1dc73f..13234889fb 100644 --- a/invokeai/app/invocations/cogview4_text_encoder.py +++ b/invokeai/app/invocations/cogview4_text_encoder.py @@ -20,7 +20,7 @@ COGVIEW4_GLM_MAX_SEQ_LEN = 1024 "cogview4_text_encoder", title="Prompt - CogView4", tags=["prompt", "conditioning", "cogview4"], - category="conditioning", + category="prompt", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/collections.py b/invokeai/app/invocations/collections.py index bd3dedb3f8..39e77f5b63 100644 --- a/invokeai/app/invocations/collections.py +++ b/invokeai/app/invocations/collections.py @@ -11,9 +11,7 @@ from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.misc import SEED_MAX -@invocation( - "range", title="Integer Range", tags=["collection", "integer", "range"], category="collections", version="1.0.0" -) +@invocation("range", title="Integer Range", tags=["collection", "integer", "range"], category="batch", version="1.0.0") class RangeInvocation(BaseInvocation): """Creates a range of numbers from start to stop with step""" @@ -35,7 +33,7 @@ class RangeInvocation(BaseInvocation): "range_of_size", title="Integer Range of Size", tags=["collection", "integer", "size", "range"], - category="collections", + category="batch", version="1.0.0", ) class RangeOfSizeInvocation(BaseInvocation): @@ -55,7 +53,7 @@ class RangeOfSizeInvocation(BaseInvocation): "random_range", title="Random Range", tags=["range", "integer", "random", "collection"], - category="collections", + category="batch", version="1.0.1", use_cache=False, ) diff --git a/invokeai/app/invocations/color_map.py b/invokeai/app/invocations/color_map.py index e55584caf5..ec95acfffd 100644 --- a/invokeai/app/invocations/color_map.py +++ b/invokeai/app/invocations/color_map.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.util import np_to_pil, pil_to_np "color_map", title="Color Map", tags=["controlnet"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class ColorMapInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py index 0ff6be969f..99373531d8 100644 --- a/invokeai/app/invocations/compel.py +++ b/invokeai/app/invocations/compel.py @@ -43,7 +43,7 @@ from invokeai.backend.util.devices import TorchDevice "compel", title="Prompt - SD1.5", tags=["prompt", "compel"], - category="conditioning", + category="prompt", version="1.2.1", ) class CompelInvocation(BaseInvocation): @@ -248,7 +248,7 @@ class SDXLPromptInvocationBase: "sdxl_compel_prompt", title="Prompt - SDXL", tags=["sdxl", "compel", "prompt"], - category="conditioning", + category="prompt", version="1.2.1", ) class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @@ -342,7 +342,7 @@ class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): "sdxl_refiner_compel_prompt", title="Prompt - SDXL Refiner", tags=["sdxl", "compel", "prompt"], - category="conditioning", + category="prompt", version="1.1.2", ) class SDXLRefinerCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase): @@ -391,7 +391,7 @@ class CLIPSkipInvocationOutput(BaseInvocationOutput): "clip_skip", title="Apply CLIP Skip - SD1.5, SDXL", tags=["clipskip", "clip", "skip"], - category="conditioning", + category="prompt", version="1.1.1", ) class CLIPSkipInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/content_shuffle.py b/invokeai/app/invocations/content_shuffle.py index e01096ecea..6fd35b53eb 100644 --- a/invokeai/app/invocations/content_shuffle.py +++ b/invokeai/app/invocations/content_shuffle.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.content_shuffle import content_shuffle "content_shuffle", title="Content Shuffle", tags=["controlnet", "normal"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class ContentShuffleInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/controlnet.py b/invokeai/app/invocations/controlnet.py index d1878d967e..9b0fc8219b 100644 --- a/invokeai/app/invocations/controlnet.py +++ b/invokeai/app/invocations/controlnet.py @@ -64,7 +64,7 @@ class ControlOutput(BaseInvocationOutput): @invocation( - "controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="controlnet", version="1.1.3" + "controlnet", title="ControlNet - SD1.5, SD2, SDXL", tags=["controlnet"], category="conditioning", version="1.1.3" ) class ControlNetInvocation(BaseInvocation): """Collects ControlNet info to pass to other nodes""" @@ -116,7 +116,7 @@ class ControlNetInvocation(BaseInvocation): "heuristic_resize", title="Heuristic Resize", tags=["image, controlnet"], - category="image", + category="controlnet_preprocessors", version="1.1.1", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/create_denoise_mask.py b/invokeai/app/invocations/create_denoise_mask.py index d013e8f4f6..419a516bcd 100644 --- a/invokeai/app/invocations/create_denoise_mask.py +++ b/invokeai/app/invocations/create_denoise_mask.py @@ -18,7 +18,7 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_t "create_denoise_mask", title="Create Denoise Mask", tags=["mask", "denoise"], - category="latents", + category="mask", version="1.0.2", ) class CreateDenoiseMaskInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/create_gradient_mask.py b/invokeai/app/invocations/create_gradient_mask.py index 8a7e7c5231..08826cc5ef 100644 --- a/invokeai/app/invocations/create_gradient_mask.py +++ b/invokeai/app/invocations/create_gradient_mask.py @@ -41,7 +41,7 @@ class GradientMaskOutput(BaseInvocationOutput): "create_gradient_mask", title="Create Gradient Mask", tags=["mask", "denoise"], - category="latents", + category="mask", version="1.3.0", ) class CreateGradientMaskInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/depth_anything.py b/invokeai/app/invocations/depth_anything.py index af79413ce0..1fd808efde 100644 --- a/invokeai/app/invocations/depth_anything.py +++ b/invokeai/app/invocations/depth_anything.py @@ -20,7 +20,7 @@ DEPTH_ANYTHING_MODELS = { "depth_anything_depth_estimation", title="Depth Anything Depth Estimation", tags=["controlnet", "depth", "depth anything"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class DepthAnythingDepthEstimationInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/dw_openpose.py b/invokeai/app/invocations/dw_openpose.py index 225c7e2283..918a4bc4d0 100644 --- a/invokeai/app/invocations/dw_openpose.py +++ b/invokeai/app/invocations/dw_openpose.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector "dw_openpose_detection", title="DW Openpose Detection", tags=["controlnet", "dwpose", "openpose"], - category="controlnet", + category="controlnet_preprocessors", version="1.1.1", ) class DWOpenposeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/facetools.py b/invokeai/app/invocations/facetools.py index 987f1b1e40..1092a67ce9 100644 --- a/invokeai/app/invocations/facetools.py +++ b/invokeai/app/invocations/facetools.py @@ -435,7 +435,9 @@ def get_faces_list( return all_faces -@invocation("face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="image", version="1.2.2") +@invocation( + "face_off", title="FaceOff", tags=["image", "faceoff", "face", "mask"], category="segmentation", version="1.2.2" +) class FaceOffInvocation(BaseInvocation, WithMetadata): """Bound, extract, and mask a face from an image using MediaPipe detection""" @@ -514,7 +516,9 @@ class FaceOffInvocation(BaseInvocation, WithMetadata): return output -@invocation("face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="image", version="1.2.2") +@invocation( + "face_mask_detection", title="FaceMask", tags=["image", "face", "mask"], category="segmentation", version="1.2.2" +) class FaceMaskInvocation(BaseInvocation, WithMetadata): """Face mask creation using mediapipe face detection""" @@ -617,7 +621,11 @@ class FaceMaskInvocation(BaseInvocation, WithMetadata): @invocation( - "face_identifier", title="FaceIdentifier", tags=["image", "face", "identifier"], category="image", version="1.2.2" + "face_identifier", + title="FaceIdentifier", + tags=["image", "face", "identifier"], + category="segmentation", + version="1.2.2", ) class FaceIdentifierInvocation(BaseInvocation, WithMetadata, WithBoard): """Outputs an image with detected face IDs printed on each face. For use with other FaceTools.""" diff --git a/invokeai/app/invocations/flux2_denoise.py b/invokeai/app/invocations/flux2_denoise.py index c387a72790..1b5ea372d6 100644 --- a/invokeai/app/invocations/flux2_denoise.py +++ b/invokeai/app/invocations/flux2_denoise.py @@ -53,7 +53,7 @@ from invokeai.backend.util.devices import TorchDevice "flux2_denoise", title="FLUX2 Denoise", tags=["image", "flux", "flux2", "klein", "denoise"], - category="image", + category="latents", version="1.4.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/flux2_klein_text_encoder.py b/invokeai/app/invocations/flux2_klein_text_encoder.py index b44e782c8a..b2728d1d7c 100644 --- a/invokeai/app/invocations/flux2_klein_text_encoder.py +++ b/invokeai/app/invocations/flux2_klein_text_encoder.py @@ -45,7 +45,7 @@ KLEIN_MAX_SEQ_LEN = 512 "flux2_klein_text_encoder", title="Prompt - Flux2 Klein", tags=["prompt", "conditioning", "flux", "klein", "qwen3"], - category="conditioning", + category="prompt", version="1.1.1", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/flux_controlnet.py b/invokeai/app/invocations/flux_controlnet.py index 8228484375..b11d497f31 100644 --- a/invokeai/app/invocations/flux_controlnet.py +++ b/invokeai/app/invocations/flux_controlnet.py @@ -50,7 +50,7 @@ class FluxControlNetOutput(BaseInvocationOutput): "flux_controlnet", title="FLUX ControlNet", tags=["controlnet", "flux"], - category="controlnet", + category="conditioning", version="1.0.0", ) class FluxControlNetInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index d6102b105b..84f0a030c5 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -70,7 +70,7 @@ from invokeai.backend.util.devices import TorchDevice "flux_denoise", title="FLUX Denoise", tags=["image", "flux"], - category="image", + category="latents", version="4.5.1", ) class FluxDenoiseInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_fill.py b/invokeai/app/invocations/flux_fill.py index cff8f2b1e5..440f3e5c97 100644 --- a/invokeai/app/invocations/flux_fill.py +++ b/invokeai/app/invocations/flux_fill.py @@ -29,7 +29,7 @@ class FluxFillOutput(BaseInvocationOutput): "flux_fill", title="FLUX Fill Conditioning", tags=["inpaint"], - category="inpaint", + category="conditioning", version="1.0.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/flux_ip_adapter.py b/invokeai/app/invocations/flux_ip_adapter.py index 4a1997c512..c0d797d0bd 100644 --- a/invokeai/app/invocations/flux_ip_adapter.py +++ b/invokeai/app/invocations/flux_ip_adapter.py @@ -24,7 +24,7 @@ from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType "flux_ip_adapter", title="FLUX IP-Adapter", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="1.0.0", ) class FluxIPAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/flux_redux.py b/invokeai/app/invocations/flux_redux.py index 403d78b078..b68e9911c5 100644 --- a/invokeai/app/invocations/flux_redux.py +++ b/invokeai/app/invocations/flux_redux.py @@ -47,7 +47,7 @@ DOWNSAMPLING_FUNCTIONS = Literal["nearest", "bilinear", "bicubic", "area", "near "flux_redux", title="FLUX Redux", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="2.1.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 56ebbe7fd9..8b3b33fad1 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -28,7 +28,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit "flux_text_encoder", title="Prompt - FLUX", tags=["prompt", "conditioning", "flux"], - category="conditioning", + category="prompt", version="1.1.2", ) class FluxTextEncoderInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/grounding_dino.py b/invokeai/app/invocations/grounding_dino.py index 1e3d5cea0c..4d900c5034 100644 --- a/invokeai/app/invocations/grounding_dino.py +++ b/invokeai/app/invocations/grounding_dino.py @@ -24,7 +24,7 @@ GROUNDING_DINO_MODEL_IDS: dict[GroundingDinoModelKey, str] = { "grounding_dino", title="Grounding DINO (Text Prompt Object Detection)", tags=["prompt", "object detection"], - category="image", + category="segmentation", version="1.0.0", ) class GroundingDinoInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/hed.py b/invokeai/app/invocations/hed.py index 5ea6e8df1f..e2b68143e5 100644 --- a/invokeai/app/invocations/hed.py +++ b/invokeai/app/invocations/hed.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.hed import ControlNetHED_Apache2, HEDEdgeDetect "hed_edge_detection", title="HED Edge Detection", tags=["controlnet", "hed", "softedge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class HEDEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/ideal_size.py b/invokeai/app/invocations/ideal_size.py index aae3a37c8e..5cfa9c04d0 100644 --- a/invokeai/app/invocations/ideal_size.py +++ b/invokeai/app/invocations/ideal_size.py @@ -21,6 +21,7 @@ class IdealSizeOutput(BaseInvocationOutput): "ideal_size", title="Ideal Size - SD1.5, SDXL", tags=["latents", "math", "ideal_size"], + category="latents", version="1.0.6", ) class IdealSizeInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/image.py b/invokeai/app/invocations/image.py index d4a1977319..17576a0296 100644 --- a/invokeai/app/invocations/image.py +++ b/invokeai/app/invocations/image.py @@ -197,7 +197,7 @@ class ImagePasteInvocation(BaseInvocation, WithMetadata, WithBoard): "tomask", title="Mask from Alpha", tags=["image", "mask"], - category="image", + category="mask", version="1.2.2", ) class MaskFromAlphaInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -604,7 +604,7 @@ class DecodeInvisibleWatermarkInvocation(BaseInvocation): "mask_edge", title="Mask Edge", tags=["image", "mask", "inpaint"], - category="image", + category="mask", version="1.2.2", ) class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -643,7 +643,7 @@ class MaskEdgeInvocation(BaseInvocation, WithMetadata, WithBoard): "mask_combine", title="Combine Masks", tags=["image", "mask", "multiply"], - category="image", + category="mask", version="1.2.2", ) class MaskCombineInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -974,7 +974,7 @@ class ImageChannelMultiplyInvocation(BaseInvocation, WithMetadata, WithBoard): "save_image", title="Save Image", tags=["primitives", "image"], - category="primitives", + category="image", version="1.2.2", use_cache=False, ) @@ -995,7 +995,7 @@ class SaveImageInvocation(BaseInvocation, WithMetadata, WithBoard): "canvas_paste_back", title="Canvas Paste Back", tags=["image", "combine"], - category="image", + category="canvas", version="1.0.1", ) class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1032,7 +1032,7 @@ class CanvasPasteBackInvocation(BaseInvocation, WithMetadata, WithBoard): "mask_from_id", title="Mask from Segmented Image", tags=["image", "mask", "id"], - category="image", + category="mask", version="1.0.1", ) class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1069,7 +1069,7 @@ class MaskFromIDInvocation(BaseInvocation, WithMetadata, WithBoard): "canvas_v2_mask_and_crop", title="Canvas V2 Mask and Crop", tags=["image", "mask", "id"], - category="image", + category="canvas", version="1.0.0", classification=Classification.Deprecated, ) @@ -1110,7 +1110,7 @@ class CanvasV2MaskAndCropInvocation(BaseInvocation, WithMetadata, WithBoard): @invocation( - "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="image", version="1.0.1" + "expand_mask_with_fade", title="Expand Mask with Fade", tags=["image", "mask"], category="mask", version="1.0.1" ) class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard): """Expands a mask with a fade effect. The mask uses black to indicate areas to keep from the generated image and white for areas to discard. @@ -1199,7 +1199,7 @@ class ExpandMaskWithFadeInvocation(BaseInvocation, WithMetadata, WithBoard): "apply_mask_to_image", title="Apply Mask to Image", tags=["image", "mask", "blend"], - category="image", + category="mask", version="1.0.0", ) class ApplyMaskToImageInvocation(BaseInvocation, WithMetadata, WithBoard): @@ -1374,7 +1374,7 @@ class PasteImageIntoBoundingBoxInvocation(BaseInvocation, WithMetadata, WithBoar "flux_kontext_image_prep", title="FLUX Kontext Image Prep", tags=["image", "concatenate", "flux", "kontext"], - category="image", + category="conditioning", version="1.0.0", ) class FluxKontextConcatenateImagesInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/image_panels.py b/invokeai/app/invocations/image_panels.py index bb9aa4995a..71fefbd1c6 100644 --- a/invokeai/app/invocations/image_panels.py +++ b/invokeai/app/invocations/image_panels.py @@ -23,7 +23,7 @@ class ImagePanelCoordinateOutput(BaseInvocationOutput): "image_panel_layout", title="Image Panel Layout", tags=["image", "panel", "layout"], - category="image", + category="canvas", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py index 2b2931e78f..711f910d58 100644 --- a/invokeai/app/invocations/ip_adapter.py +++ b/invokeai/app/invocations/ip_adapter.py @@ -73,7 +73,7 @@ CLIP_VISION_MODEL_MAP: dict[Literal["ViT-L", "ViT-H", "ViT-G"], StarterModel] = "ip_adapter", title="IP-Adapter - SD1.5, SDXL", tags=["ip_adapter", "control"], - category="ip_adapter", + category="conditioning", version="1.5.1", ) class IPAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/lineart.py b/invokeai/app/invocations/lineart.py index c486c329ec..3ffd51b5b6 100644 --- a/invokeai/app/invocations/lineart.py +++ b/invokeai/app/invocations/lineart.py @@ -11,7 +11,7 @@ from invokeai.backend.image_util.lineart import Generator, LineartEdgeDetector "lineart_edge_detection", title="Lineart Edge Detection", tags=["controlnet", "lineart"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class LineartEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/lineart_anime.py b/invokeai/app/invocations/lineart_anime.py index 848756b113..f07476491c 100644 --- a/invokeai/app/invocations/lineart_anime.py +++ b/invokeai/app/invocations/lineart_anime.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.lineart_anime import LineartAnimeEdgeDetector, "lineart_anime_edge_detection", title="Lineart Anime Edge Detection", tags=["controlnet", "lineart"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class LineartAnimeEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/llava_onevision_vllm.py b/invokeai/app/invocations/llava_onevision_vllm.py index fbd2420590..ff3b801d37 100644 --- a/invokeai/app/invocations/llava_onevision_vllm.py +++ b/invokeai/app/invocations/llava_onevision_vllm.py @@ -19,7 +19,7 @@ from invokeai.backend.util.devices import TorchDevice "llava_onevision_vllm", title="LLaVA OneVision VLLM", tags=["vllm"], - category="vllm", + category="multimodal", version="1.0.0", classification=Classification.Beta, ) diff --git a/invokeai/app/invocations/logic.py b/invokeai/app/invocations/logic.py index 3197427d4e..7cc98afbbc 100644 --- a/invokeai/app/invocations/logic.py +++ b/invokeai/app/invocations/logic.py @@ -12,7 +12,7 @@ class IfInvocationOutput(BaseInvocationOutput): ) -@invocation("if", title="If", tags=["logic", "conditional"], category="logic", version="1.0.0") +@invocation("if", title="If", tags=["logic", "conditional"], category="math", version="1.0.0") class IfInvocation(BaseInvocation): """Selects between two optional inputs based on a boolean condition.""" diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py index 556ab8801d..49749f43b6 100644 --- a/invokeai/app/invocations/mask.py +++ b/invokeai/app/invocations/mask.py @@ -24,7 +24,7 @@ from invokeai.backend.image_util.util import pil_to_np "rectangle_mask", title="Create Rectangle Mask", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.1", ) class RectangleMaskInvocation(BaseInvocation, WithMetadata): @@ -55,7 +55,7 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata): "alpha_mask_to_tensor", title="Alpha Mask to Tensor", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.0", ) class AlphaMaskToTensorInvocation(BaseInvocation): @@ -83,7 +83,7 @@ class AlphaMaskToTensorInvocation(BaseInvocation): "invert_tensor_mask", title="Invert Tensor Mask", tags=["conditioning"], - category="conditioning", + category="mask", version="1.1.0", ) class InvertTensorMaskInvocation(BaseInvocation): @@ -115,7 +115,7 @@ class InvertTensorMaskInvocation(BaseInvocation): "image_mask_to_tensor", title="Image Mask to Tensor", tags=["conditioning"], - category="conditioning", + category="mask", version="1.0.0", ) class ImageMaskToTensorInvocation(BaseInvocation, WithMetadata): diff --git a/invokeai/app/invocations/mediapipe_face.py b/invokeai/app/invocations/mediapipe_face.py index 89fccfc1ac..e81326463c 100644 --- a/invokeai/app/invocations/mediapipe_face.py +++ b/invokeai/app/invocations/mediapipe_face.py @@ -9,7 +9,7 @@ from invokeai.backend.image_util.mediapipe_face import detect_faces "mediapipe_face_detection", title="MediaPipe Face Detection", tags=["controlnet", "face"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class MediaPipeFaceDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/metadata_linked.py b/invokeai/app/invocations/metadata_linked.py index 6a9db3e589..53f2ea7471 100644 --- a/invokeai/app/invocations/metadata_linked.py +++ b/invokeai/app/invocations/metadata_linked.py @@ -621,7 +621,7 @@ class LatentsMetaOutput(LatentsOutput, MetadataOutput): "denoise_latents_meta", title=f"{DenoiseLatentsInvocation.UIConfig.title} + Metadata", tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.1.1", ) class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata): @@ -686,7 +686,7 @@ class DenoiseLatentsMetaInvocation(DenoiseLatentsInvocation, WithMetadata): "flux_denoise_meta", title=f"{FluxDenoiseInvocation.UIConfig.title} + Metadata", tags=["flux", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.0.1", ) class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata): @@ -734,7 +734,7 @@ class FluxDenoiseLatentsMetaInvocation(FluxDenoiseInvocation, WithMetadata): "z_image_denoise_meta", title=f"{ZImageDenoiseInvocation.UIConfig.title} + Metadata", tags=["z-image", "latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], - category="latents", + category="metadata", version="1.0.0", ) class ZImageDenoiseMetaInvocation(ZImageDenoiseInvocation, WithMetadata): diff --git a/invokeai/app/invocations/mlsd.py b/invokeai/app/invocations/mlsd.py index 1526350db8..a2446876c8 100644 --- a/invokeai/app/invocations/mlsd.py +++ b/invokeai/app/invocations/mlsd.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.mlsd.models.mbv2_mlsd_large import MobileV2_MLS "mlsd_detection", title="MLSD Detection", tags=["controlnet", "mlsd", "edge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class MLSDDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 6b5afb5529..0c96cdb1d9 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -584,7 +584,7 @@ class SeamlessModeInvocation(BaseInvocation): return SeamlessModeOutput(unet=unet, vae=vae) -@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="unet", version="1.0.2") +@invocation("freeu", title="Apply FreeU - SD1.5, SDXL", tags=["freeu"], category="model", version="1.0.2") class FreeUInvocation(BaseInvocation): """ Applies FreeU to the UNet. Suggested values (b1/b2/s1/s2): diff --git a/invokeai/app/invocations/normal_bae.py b/invokeai/app/invocations/normal_bae.py index ebbea869a1..1159927150 100644 --- a/invokeai/app/invocations/normal_bae.py +++ b/invokeai/app/invocations/normal_bae.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.normal_bae.nets.NNET import NNET "normal_map", title="Normal Map", tags=["controlnet", "normal"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class NormalMapInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/pbr_maps.py b/invokeai/app/invocations/pbr_maps.py index 5e519d38bc..945c3cad59 100644 --- a/invokeai/app/invocations/pbr_maps.py +++ b/invokeai/app/invocations/pbr_maps.py @@ -16,7 +16,9 @@ class PBRMapsOutput(BaseInvocationOutput): displacement_map: ImageField = OutputField(default=None, description="The generated displacement map") -@invocation("pbr_maps", title="PBR Maps", tags=["image", "material"], category="image", version="1.0.0") +@invocation( + "pbr_maps", title="PBR Maps", tags=["image", "material"], category="controlnet_preprocessors", version="1.0.0" +) class PBRMapsInvocation(BaseInvocation, WithMetadata, WithBoard): """Generate Normal, Displacement and Roughness Map from a given image""" diff --git a/invokeai/app/invocations/pidi.py b/invokeai/app/invocations/pidi.py index 47b241ee1f..5d8cab0458 100644 --- a/invokeai/app/invocations/pidi.py +++ b/invokeai/app/invocations/pidi.py @@ -10,7 +10,7 @@ from invokeai.backend.image_util.pidi.model import PiDiNet "pidi_edge_detection", title="PiDiNet Edge Detection", tags=["controlnet", "edge"], - category="controlnet", + category="controlnet_preprocessors", version="1.0.0", ) class PiDiNetEdgeDetectionInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/sd3_denoise.py b/invokeai/app/invocations/sd3_denoise.py index b9d69369b7..4b990ee42b 100644 --- a/invokeai/app/invocations/sd3_denoise.py +++ b/invokeai/app/invocations/sd3_denoise.py @@ -34,7 +34,7 @@ from invokeai.backend.util.devices import TorchDevice "sd3_denoise", title="Denoise - SD3", tags=["image", "sd3"], - category="image", + category="latents", version="1.1.1", ) class SD3DenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/sd3_image_to_latents.py b/invokeai/app/invocations/sd3_image_to_latents.py index 71a48ee9ad..9af641d8bc 100644 --- a/invokeai/app/invocations/sd3_image_to_latents.py +++ b/invokeai/app/invocations/sd3_image_to_latents.py @@ -24,7 +24,7 @@ from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory "sd3_i2l", title="Image to Latents - SD3", tags=["image", "latents", "vae", "i2l", "sd3"], - category="image", + category="latents", version="1.0.1", ) class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard): diff --git a/invokeai/app/invocations/sd3_text_encoder.py b/invokeai/app/invocations/sd3_text_encoder.py index 58880f9a28..7af138fe45 100644 --- a/invokeai/app/invocations/sd3_text_encoder.py +++ b/invokeai/app/invocations/sd3_text_encoder.py @@ -31,7 +31,7 @@ SD3_T5_MAX_SEQ_LEN = 256 "sd3_text_encoder", title="Prompt - SD3", tags=["prompt", "conditioning", "sd3"], - category="conditioning", + category="prompt", version="1.0.1", ) class Sd3TextEncoderInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/strings.py b/invokeai/app/invocations/strings.py index 2b6bf300b9..6d64e8771a 100644 --- a/invokeai/app/invocations/strings.py +++ b/invokeai/app/invocations/strings.py @@ -20,7 +20,7 @@ class StringPosNegOutput(BaseInvocationOutput): "string_split_neg", title="String Split Negative", tags=["string", "split", "negative"], - category="string", + category="strings", version="1.0.1", ) class StringSplitNegInvocation(BaseInvocation): @@ -63,7 +63,7 @@ class String2Output(BaseInvocationOutput): string_2: str = OutputField(description="string 2") -@invocation("string_split", title="String Split", tags=["string", "split"], category="string", version="1.0.1") +@invocation("string_split", title="String Split", tags=["string", "split"], category="strings", version="1.0.1") class StringSplitInvocation(BaseInvocation): """Splits string into two strings, based on the first occurance of the delimiter. The delimiter will be removed from the string""" @@ -83,7 +83,7 @@ class StringSplitInvocation(BaseInvocation): return String2Output(string_1=part1, string_2=part2) -@invocation("string_join", title="String Join", tags=["string", "join"], category="string", version="1.0.1") +@invocation("string_join", title="String Join", tags=["string", "join"], category="strings", version="1.0.1") class StringJoinInvocation(BaseInvocation): """Joins string left to string right""" @@ -94,7 +94,9 @@ class StringJoinInvocation(BaseInvocation): return StringOutput(value=((self.string_left or "") + (self.string_right or ""))) -@invocation("string_join_three", title="String Join Three", tags=["string", "join"], category="string", version="1.0.1") +@invocation( + "string_join_three", title="String Join Three", tags=["string", "join"], category="strings", version="1.0.1" +) class StringJoinThreeInvocation(BaseInvocation): """Joins string left to string middle to string right""" @@ -107,7 +109,7 @@ class StringJoinThreeInvocation(BaseInvocation): @invocation( - "string_replace", title="String Replace", tags=["string", "replace", "regex"], category="string", version="1.0.1" + "string_replace", title="String Replace", tags=["string", "replace", "regex"], category="strings", version="1.0.1" ) class StringReplaceInvocation(BaseInvocation): """Replaces the search string with the replace string""" diff --git a/invokeai/app/invocations/t2i_adapter.py b/invokeai/app/invocations/t2i_adapter.py index 15f1881eef..cf4b7cda47 100644 --- a/invokeai/app/invocations/t2i_adapter.py +++ b/invokeai/app/invocations/t2i_adapter.py @@ -49,7 +49,7 @@ class T2IAdapterOutput(BaseInvocationOutput): "t2i_adapter", title="T2I-Adapter - SD1.5, SDXL", tags=["t2i_adapter", "control"], - category="t2i_adapter", + category="conditioning", version="1.0.4", ) class T2IAdapterInvocation(BaseInvocation): diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py index e7b3968aec..64e372a0f6 100644 --- a/invokeai/app/invocations/upscale.py +++ b/invokeai/app/invocations/upscale.py @@ -30,7 +30,7 @@ ESRGAN_MODEL_URLS: dict[str, str] = { } -@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.2") +@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="upscale", version="1.3.2") class ESRGANInvocation(BaseInvocation, WithMetadata, WithBoard): """Upscales an image using RealESRGAN.""" diff --git a/invokeai/app/invocations/z_image_control.py b/invokeai/app/invocations/z_image_control.py index 3b01f12373..f51c2fcd16 100644 --- a/invokeai/app/invocations/z_image_control.py +++ b/invokeai/app/invocations/z_image_control.py @@ -57,7 +57,7 @@ class ZImageControlOutput(BaseInvocationOutput): "z_image_control", title="Z-Image ControlNet", tags=["image", "z-image", "control", "controlnet"], - category="control", + category="conditioning", version="1.1.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_denoise.py b/invokeai/app/invocations/z_image_denoise.py index 24b135e447..397e917112 100644 --- a/invokeai/app/invocations/z_image_denoise.py +++ b/invokeai/app/invocations/z_image_denoise.py @@ -49,7 +49,7 @@ from invokeai.backend.z_image.z_image_transformer_patch import patch_transformer "z_image_denoise", title="Denoise - Z-Image", tags=["image", "z-image"], - category="image", + category="latents", version="1.5.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_image_to_latents.py b/invokeai/app/invocations/z_image_image_to_latents.py index 5a70fdba13..263346e296 100644 --- a/invokeai/app/invocations/z_image_image_to_latents.py +++ b/invokeai/app/invocations/z_image_image_to_latents.py @@ -30,7 +30,7 @@ ZImageVAE = Union[AutoencoderKL, FluxAutoEncoder] "z_image_i2l", title="Image to Latents - Z-Image", tags=["image", "latents", "vae", "i2l", "z-image"], - category="image", + category="latents", version="1.1.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_seed_variance_enhancer.py b/invokeai/app/invocations/z_image_seed_variance_enhancer.py index b24002e971..72819a966a 100644 --- a/invokeai/app/invocations/z_image_seed_variance_enhancer.py +++ b/invokeai/app/invocations/z_image_seed_variance_enhancer.py @@ -19,7 +19,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( "z_image_seed_variance_enhancer", title="Seed Variance Enhancer - Z-Image", tags=["conditioning", "z-image", "variance", "seed"], - category="conditioning", + category="prompt", version="1.0.0", classification=Classification.Prototype, ) diff --git a/invokeai/app/invocations/z_image_text_encoder.py b/invokeai/app/invocations/z_image_text_encoder.py index c3405d6dc8..71af6085d0 100644 --- a/invokeai/app/invocations/z_image_text_encoder.py +++ b/invokeai/app/invocations/z_image_text_encoder.py @@ -34,7 +34,7 @@ Z_IMAGE_MAX_SEQ_LEN = 512 "z_image_text_encoder", title="Prompt - Z-Image", tags=["prompt", "conditioning", "z-image"], - category="conditioning", + category="prompt", version="1.1.0", classification=Classification.Prototype, ) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 0e2f982ea6..6d79635522 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -212,6 +212,7 @@ "copy": "Copy", "copyError": "$t(gallery.copy) Error", "clipboard": "Clipboard", + "collapseAll": "Collapse All", "crop": "Crop", "on": "On", "off": "Off", @@ -239,6 +240,7 @@ "error": "Error", "error_withCount_one": "{{count}} error", "error_withCount_other": "{{count}} errors", + "expandAll": "Expand All", "model_withCount_one": "{{count}} model", "model_withCount_other": "{{count}} models", "file": "File", @@ -1383,6 +1385,8 @@ "fullyContainNodesHelp": "Nodes must be fully inside the selection box to be selected", "showEdgeLabels": "Show Edge Labels", "showEdgeLabelsHelp": "Show labels on edges, indicating the connected nodes", + "groupNodesByCategory": "Group Nodes by Category", + "groupNodesByCategoryHelp": "Group nodes by category in the add node dialog", "hideLegendNodes": "Hide Field Type Legend", "hideMinimapnodes": "Hide MiniMap", "inputMayOnlyHaveOneConnection": "Input may only have one connection", diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk.tsx index 9feae5215a..4a72cad8bf 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodeCmdk/AddNodeCmdk.tsx @@ -1,6 +1,7 @@ import type { SystemStyleObject } from '@invoke-ai/ui-library'; import { Box, + Button, Flex, Icon, Input, @@ -17,6 +18,7 @@ import { useAppSelector, useAppStore } from 'app/store/storeHooks'; import { CommandEmpty, CommandItem, CommandList, CommandRoot } from 'cmdk'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; +import { capitalize } from 'es-toolkit'; import { memoize } from 'es-toolkit/compat'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; import { @@ -33,16 +35,24 @@ import { findUnoccupiedPosition } from 'features/nodes/store/util/findUnoccupied import { getFirstValidConnection } from 'features/nodes/store/util/getFirstValidConnection'; import { connectionToEdge } from 'features/nodes/store/util/reactFlowUtil'; import { validateConnectionTypes } from 'features/nodes/store/util/validateConnectionTypes'; +import { selectShouldGroupNodesByCategory } from 'features/nodes/store/workflowSettingsSlice'; import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { useRegisteredHotkeys } from 'features/system/components/HotkeysModal/useHotkeyData'; import { toast } from 'features/toast/toast'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; import { computed } from 'nanostores'; -import type { ChangeEvent } from 'react'; +import type { ChangeEvent, Dispatch, SetStateAction } from 'react'; import { memo, useCallback, useMemo, useRef, useState } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiCircuitryBold, PiFlaskBold, PiHammerBold, PiLightningFill } from 'react-icons/pi'; +import { + PiCaretDownBold, + PiCaretRightBold, + PiCircuitryBold, + PiFlaskBold, + PiHammerBold, + PiLightningFill, +} from 'react-icons/pi'; import type { S } from 'services/api/types'; import { objectEntries } from 'tsafe'; import { useDebounce } from 'use-debounce'; @@ -171,15 +181,36 @@ export const AddNodeCmdk = memo(() => { const onClose = useCallback(() => { close(); setSearchTerm(''); + setExpandedCategories(new Set()); $pendingConnection.set(null); }, [close]); + const [expandedCategories, setExpandedCategories] = useState>(new Set()); + + const toggleCategory = useCallback((category: string) => { + setExpandedCategories((prev) => { + const next = new Set(prev); + if (next.has(category)) { + next.delete(category); + } else { + next.add(category); + } + return next; + }); + }, []); + const onSelect = useCallback( (value: string) => { + // Category headers have a special prefix + if (value.startsWith('__category__:')) { + const category = value.slice('__category__:'.length); + toggleCategory(category); + return; + } addNode(value); onClose(); }, - [addNode, onClose] + [addNode, onClose, toggleCategory] ); return ( @@ -204,7 +235,12 @@ export const AddNodeCmdk = memo(() => { /> - + @@ -230,6 +266,7 @@ type NodeCommandItemData = { description: string; classification: S['Classification']; nodePack: string; + category: string; }; /** @@ -260,6 +297,7 @@ type FilterableItem = { tags: string[]; classification: S['Classification']; nodePack: string; + category: string; }; const filter = memoize( @@ -290,6 +328,10 @@ const filter = memoize( return true; } + if (item.category.includes(searchTerm) || regex.test(item.category)) { + return true; + } + for (const tag of item.tags) { if (tag.includes(searchTerm) || regex.test(tag)) { return true; @@ -301,112 +343,253 @@ const filter = memoize( (item: FilterableItem, searchTerm: string) => `${item.type}-${searchTerm}` ); -const NodeCommandList = memo(({ searchTerm, onSelect }: { searchTerm: string; onSelect: (value: string) => void }) => { - const { t } = useTranslation(); - const templatesArray = useStore($templatesArray); - const pendingConnection = useStore($pendingConnection); - const currentImageFilterItem = useMemo( - () => ({ - type: 'current_image', - title: t('nodes.currentImage'), - description: t('nodes.currentImageDescription'), - tags: ['progress', 'image', 'current'], - classification: 'stable', - nodePack: 'invokeai', - }), - [t] - ); - const notesFilterItem = useMemo( - () => ({ - type: 'notes', - title: t('nodes.notes'), - description: t('nodes.notesDescription'), - tags: ['notes'], - classification: 'stable', - nodePack: 'invokeai', - }), - [t] - ); +const categoryItemSx: SystemStyleObject = { + cursor: 'pointer', + userSelect: 'none', + '&[data-selected="true"]': { + bg: 'base.750', + }, +}; - const items = useMemo(() => { - // If we have a connection in progress, we need to filter the node choices - const _items: NodeCommandItemData[] = []; +const NodeCommandItem = memo( + ({ + item, + onSelect, + isGrouped, + }: { + item: NodeCommandItemData; + onSelect: (value: string) => void; + isGrouped?: boolean; + }) => ( + + + + {item.classification === 'beta' && } + {item.classification === 'prototype' && } + {item.classification === 'internal' && } + {item.classification === 'special' && } + {item.label} + + + {item.nodePack} + + + {item.description && {item.description}} + + + ) +); - if (!pendingConnection) { - for (const template of templatesArray) { - if (filter(template, searchTerm)) { - _items.push({ - label: template.title, - value: template.type, - description: template.description, - classification: template.classification, - nodePack: template.nodePack, - }); +NodeCommandItem.displayName = 'NodeCommandItem'; + +const NodeCommandList = memo( + ({ + searchTerm, + onSelect, + expandedCategories, + setExpandedCategories, + }: { + searchTerm: string; + onSelect: (value: string) => void; + expandedCategories: Set; + setExpandedCategories: Dispatch>>; + }) => { + const { t } = useTranslation(); + const templatesArray = useStore($templatesArray); + const pendingConnection = useStore($pendingConnection); + const shouldGroupNodesByCategory = useAppSelector(selectShouldGroupNodesByCategory); + const currentImageFilterItem = useMemo( + () => ({ + type: 'current_image', + title: t('nodes.currentImage'), + description: t('nodes.currentImageDescription'), + tags: ['progress', 'image', 'current'], + classification: 'stable', + nodePack: 'invokeai', + category: 'image', + }), + [t] + ); + const notesFilterItem = useMemo( + () => ({ + type: 'notes', + title: t('nodes.notes'), + description: t('nodes.notesDescription'), + tags: ['notes'], + classification: 'stable', + nodePack: 'invokeai', + category: 'other', + }), + [t] + ); + + const items = useMemo(() => { + // If we have a connection in progress, we need to filter the node choices + const _items: NodeCommandItemData[] = []; + + if (!pendingConnection) { + for (const template of templatesArray) { + if (filter(template, searchTerm)) { + _items.push({ + label: template.title, + value: template.type, + description: template.description, + classification: template.classification, + nodePack: template.nodePack, + category: template.category, + }); + } } - } - for (const item of [currentImageFilterItem, notesFilterItem]) { - if (filter(item, searchTerm)) { - _items.push({ - label: item.title, - value: item.type, - description: item.description, - classification: item.classification, - nodePack: item.nodePack, - }); + for (const item of [currentImageFilterItem, notesFilterItem]) { + if (filter(item, searchTerm)) { + _items.push({ + label: item.title, + value: item.type, + description: item.description, + classification: item.classification, + nodePack: item.nodePack, + category: item.category, + }); + } } - } - } else { - for (const template of templatesArray) { - if (filter(template, searchTerm)) { - const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs; + } else { + for (const template of templatesArray) { + if (filter(template, searchTerm)) { + const candidateFields = pendingConnection.handleType === 'source' ? template.inputs : template.outputs; - for (const [_fieldName, fieldTemplate] of objectEntries(candidateFields)) { - const sourceType = - pendingConnection.handleType === 'source' ? pendingConnection.fieldTemplate.type : fieldTemplate.type; - const targetType = - pendingConnection.handleType === 'target' ? pendingConnection.fieldTemplate.type : fieldTemplate.type; + for (const [_fieldName, fieldTemplate] of objectEntries(candidateFields)) { + const sourceType = + pendingConnection.handleType === 'source' ? pendingConnection.fieldTemplate.type : fieldTemplate.type; + const targetType = + pendingConnection.handleType === 'target' ? pendingConnection.fieldTemplate.type : fieldTemplate.type; - if (validateConnectionTypes(sourceType, targetType)) { - _items.push({ - label: template.title, - value: template.type, - description: template.description, - classification: template.classification, - nodePack: template.nodePack, - }); - break; + if (validateConnectionTypes(sourceType, targetType)) { + _items.push({ + label: template.title, + value: template.type, + description: template.description, + classification: template.classification, + nodePack: template.nodePack, + category: template.category, + }); + break; + } } } } } + + // Sort exact title matches to the top when searching + if (searchTerm) { + const lowerSearch = searchTerm.toLowerCase(); + _items.sort((a, b) => { + const aExact = a.label.toLowerCase() === lowerSearch; + const bExact = b.label.toLowerCase() === lowerSearch; + if (aExact && !bExact) { + return -1; + } + if (!aExact && bExact) { + return 1; + } + return 0; + }); + } + + return _items; + }, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem]); + + const groupedItems = useMemo(() => { + const groups: Record = {}; + for (const item of items) { + const cat = item.category; + if (!groups[cat]) { + groups[cat] = []; + } + groups[cat].push(item); + } + // Sort categories alphabetically, but put "other" last. + // When searching, prioritize categories that contain an exact title match. + const lowerSearch = searchTerm.toLowerCase(); + return Object.entries(groups).sort(([a, aItems], [b, bItems]) => { + if (searchTerm) { + const aHasExact = aItems.some((item) => item.label.toLowerCase() === lowerSearch); + const bHasExact = bItems.some((item) => item.label.toLowerCase() === lowerSearch); + if (aHasExact && !bHasExact) { + return -1; + } + if (!aHasExact && bHasExact) { + return 1; + } + } + if (a === 'other') { + return 1; + } + if (b === 'other') { + return -1; + } + return a.localeCompare(b); + }); + }, [items, searchTerm]); + + // When searching, auto-expand all categories; when not searching, use manual state + const isSearching = searchTerm.length > 0; + + const expandAll = useCallback(() => { + setExpandedCategories(new Set(groupedItems.map(([cat]) => cat))); + }, [groupedItems, setExpandedCategories]); + + const collapseAll = useCallback(() => { + setExpandedCategories(new Set()); + }, [setExpandedCategories]); + + if (!shouldGroupNodesByCategory) { + return ( + <> + {items.map((item) => ( + + ))} + + ); } - return _items; - }, [pendingConnection, templatesArray, searchTerm, currentImageFilterItem, notesFilterItem]); - - return ( - <> - {items.map((item) => ( - - - - {item.classification === 'beta' && } - {item.classification === 'prototype' && } - {item.classification === 'internal' && } - {item.classification === 'special' && } - {item.label} - - - {item.nodePack} - - - {item.description && {item.description}} + return ( + <> + {!isSearching && ( + + + - - ))} - - ); -}); + )} + {groupedItems.map(([category, categoryItems]) => { + const isExpanded = isSearching || expandedCategories.has(category); + return ( + + + + + + {capitalize(category)} + + + ({categoryItems.length}) + + + + {isExpanded && + categoryItems.map((item) => ( + + ))} + + ); + })} + + ); + } +); NodeCommandList.displayName = 'CommandListItems'; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings.tsx index 0d3ca06c8a..2009f92144 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/panels/TopRightPanel/WorkflowEditorSettings.tsx @@ -24,11 +24,13 @@ import { selectSelectionMode, selectShouldAnimateEdges, selectShouldColorEdges, + selectShouldGroupNodesByCategory, selectShouldShouldValidateGraph, selectShouldShowEdgeLabels, selectShouldSnapToGrid, shouldAnimateEdgesChanged, shouldColorEdgesChanged, + shouldGroupNodesByCategoryChanged, shouldShowEdgeLabelsChanged, shouldSnapToGridChanged, shouldValidateGraphChanged, @@ -50,6 +52,7 @@ const WorkflowEditorSettings = () => { const shouldAnimateEdges = useAppSelector(selectShouldAnimateEdges); const shouldShowEdgeLabels = useAppSelector(selectShouldShowEdgeLabels); const shouldValidateGraph = useAppSelector(selectShouldShouldValidateGraph); + const shouldGroupNodesByCategory = useAppSelector(selectShouldGroupNodesByCategory); const handleChangeShouldValidate = useCallback( (e: ChangeEvent) => { @@ -93,6 +96,13 @@ const WorkflowEditorSettings = () => { [dispatch] ); + const handleChangeShouldGroupNodesByCategory = useCallback( + (e: ChangeEvent) => { + dispatch(shouldGroupNodesByCategoryChanged(e.target.checked)); + }, + [dispatch] + ); + const { t } = useTranslation(); return ( @@ -145,6 +155,14 @@ const WorkflowEditorSettings = () => { {t('nodes.showEdgeLabelsHelp')} + + + {t('nodes.groupNodesByCategory')} + + + {t('nodes.groupNodesByCategoryHelp')} + + {t('common.advanced')} diff --git a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts index 1eb445beaf..8706e199bb 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/testUtils.ts @@ -70,6 +70,7 @@ export const add: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'math', }; export const sub: InvocationTemplate = { @@ -128,6 +129,7 @@ export const sub: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'math', }; export const collect: InvocationTemplate = { @@ -189,6 +191,7 @@ export const collect: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'collections', }; const scheduler: InvocationTemplate = { @@ -245,6 +248,7 @@ const scheduler: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'other', }; export const main_model_loader: InvocationTemplate = { @@ -313,6 +317,7 @@ export const main_model_loader: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'model', }; export const img_resize: InvocationTemplate = { @@ -457,6 +462,7 @@ export const img_resize: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'image', }; const iterate: InvocationTemplate = { @@ -526,6 +532,7 @@ const iterate: InvocationTemplate = { useCache: true, nodePack: 'invokeai', classification: 'stable', + category: 'collections', }; export const templates: Templates = { @@ -713,7 +720,6 @@ export const schema = { required: ['type', 'id'], title: 'Scheduler', description: 'Selects a scheduler.', - category: 'latents', classification: 'stable', node_pack: 'invokeai', tags: ['scheduler'], @@ -1199,6 +1205,7 @@ export const schema = { title: 'CollectInvocation', node_pack: 'invokeai', description: 'Collects values into a collection', + category: 'collections', classification: 'stable', version: '1.1.0', output: { @@ -1558,6 +1565,7 @@ export const schema = { required: ['type', 'id'], title: 'IterateInvocation', description: 'Iterates over a list of items', + category: 'collections', classification: 'stable', node_pack: 'invokeai', version: '1.1.0', diff --git a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts index 29395a1e1d..730dced1d3 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/validateConnection.test.ts @@ -16,6 +16,7 @@ const ifTemplate: InvocationTemplate = { type: 'if', version: '1.0.0', tags: [], + category: 'math', description: 'Selects between two inputs based on a boolean condition', outputType: 'if_output', inputs: { @@ -93,6 +94,7 @@ const floatOutputTemplate: InvocationTemplate = { type: 'float_output', version: '1.0.0', tags: [], + category: 'primitives', description: 'Outputs a float', outputType: 'float_output', inputs: {}, @@ -121,6 +123,7 @@ const integerCollectionOutputTemplate: InvocationTemplate = { type: 'integer_collection_output', version: '1.0.0', tags: [], + category: 'primitives', description: 'Outputs an integer collection', outputType: 'integer_collection_output', inputs: {}, diff --git a/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts b/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts index 85b803acd4..7c84bd6f1e 100644 --- a/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/workflowSettingsSlice.ts @@ -30,6 +30,7 @@ const zWorkflowSettingsState = z.object({ shouldColorEdges: z.boolean(), shouldShowEdgeLabels: z.boolean(), selectionMode: zSelectionMode, + shouldGroupNodesByCategory: z.boolean(), }); export type WorkflowSettingsState = z.infer; @@ -49,6 +50,7 @@ const getInitialState = (): WorkflowSettingsState => ({ shouldShowEdgeLabels: false, nodeOpacity: 1, selectionMode: 'partial', + shouldGroupNodesByCategory: true, }); const slice = createSlice({ @@ -94,6 +96,9 @@ const slice = createSlice({ selectionModeChanged: (state, action: PayloadAction) => { state.selectionMode = action.payload ? 'full' : 'partial'; }, + shouldGroupNodesByCategoryChanged: (state, action: PayloadAction) => { + state.shouldGroupNodesByCategory = action.payload; + }, }, }); @@ -111,6 +116,7 @@ export const { shouldValidateGraphChanged, nodeOpacityChanged, selectionModeChanged, + shouldGroupNodesByCategoryChanged, } = slice.actions; export const workflowSettingsSliceConfig: SliceConfig = { @@ -123,6 +129,9 @@ export const workflowSettingsSliceConfig: SliceConfig = { if (!('_version' in state)) { state._version = 1; } + if (!('shouldGroupNodesByCategory' in state)) { + state.shouldGroupNodesByCategory = false; + } return zWorkflowSettingsState.parse(state); }, }, @@ -145,3 +154,4 @@ export const selectNodeSpacing = createWorkflowSettingsSelector((s) => s.nodeSpa export const selectLayerSpacing = createWorkflowSettingsSelector((s) => s.layerSpacing); export const selectLayoutDirection = createWorkflowSettingsSelector((s) => s.layoutDirection); export const selectNodeAlignment = createWorkflowSettingsSelector((s) => s.nodeAlignment); +export const selectShouldGroupNodesByCategory = createWorkflowSettingsSelector((s) => s.shouldGroupNodesByCategory); diff --git a/invokeai/frontend/web/src/features/nodes/types/invocation.ts b/invokeai/frontend/web/src/features/nodes/types/invocation.ts index 9e82144e4e..5d8d85dd87 100644 --- a/invokeai/frontend/web/src/features/nodes/types/invocation.ts +++ b/invokeai/frontend/web/src/features/nodes/types/invocation.ts @@ -18,6 +18,7 @@ const _zInvocationTemplate = z.object({ useCache: z.boolean(), nodePack: z.string().min(1).default('invokeai'), classification: zClassification, + category: z.string().default('other'), }); export type InvocationTemplate = z.infer; // #endregion diff --git a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts index 57cd9943c5..47be2c62ec 100644 --- a/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts +++ b/invokeai/frontend/web/src/features/nodes/util/schema/parseSchema.ts @@ -113,6 +113,7 @@ export const parseSchema = ( const version = schema.version; const nodePack = schema.node_pack; const classification = schema.classification; + const category = schema.category ?? 'other'; const inputs = reduce( schema.properties, @@ -260,6 +261,7 @@ export const parseSchema = ( useCache, nodePack, classification, + category, }; Object.assign(invocationsAccumulator, { [type]: invocation }); From acd4157bdf82208c84100adc7a6da570c85b428e Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Tue, 14 Apr 2026 03:01:57 +0200 Subject: [PATCH 5/6] feat(ui): add canvas project save/load (.invk format) (#8917) * feat(ui): add canvas project save/load (.invk format) Add ZIP-based .invk file format to save and restore the entire canvas state including all layers, masks, reference images, generation parameters, LoRAs, and embedded image files. Images are deduplicated on load - only missing images are re-uploaded from the project file. - Always clear LoRAs on project load, even when project has none - Fix jszip dependency ordering in package.json - Add useAssertSingleton to SaveCanvasProjectDialog for consistency - Add concurrency limit (max 5) for image fetch/upload requests - Remove redundant deep-clone in remapCroppableImage (mutate in-place) - Default project name to "Canvas Project" instead of empty string * Chore pnpm fix --- .../CANVAS_PROJECTS/CANVAS_PROJECTS.md | 205 +++++++++++++ docs/features/canvas_projects.md | 56 ++++ invokeai/frontend/web/package.json | 1 + invokeai/frontend/web/pnpm-lock.yaml | 75 +++++ invokeai/frontend/web/public/locales/en.json | 13 + .../app/components/GlobalModalIsolator.tsx | 4 + .../CanvasContextMenuGlobalMenuItems.tsx | 22 +- ...adCanvasProjectConfirmationAlertDialog.tsx | 69 +++++ .../components/SaveCanvasProjectDialog.tsx | 92 ++++++ .../components/Toolbar/CanvasToolbar.tsx | 2 + .../CanvasToolbarProjectMenuButton.tsx | 37 +++ .../hooks/useCanvasProjectLoad.ts | 157 ++++++++++ .../hooks/useCanvasProjectSave.ts | 116 +++++++ .../controlLayers/store/canvasSlice.ts | 31 ++ .../controlLayers/store/paramsSlice.ts | 4 + .../controlLayers/util/canvasProjectFile.ts | 287 ++++++++++++++++++ 16 files changed, 1170 insertions(+), 1 deletion(-) create mode 100644 docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md create mode 100644 docs/features/canvas_projects.md create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/SaveCanvasProjectDialog.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx create mode 100644 invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectLoad.ts create mode 100644 invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectSave.ts create mode 100644 invokeai/frontend/web/src/features/controlLayers/util/canvasProjectFile.ts diff --git a/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md b/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md new file mode 100644 index 0000000000..a05ef29492 --- /dev/null +++ b/docs/contributing/CANVAS_PROJECTS/CANVAS_PROJECTS.md @@ -0,0 +1,205 @@ +# Canvas Projects — Technical Documentation + +## Overview + +Canvas Projects provide a save/load mechanism for the entire canvas state. The feature serializes all canvas entities, generation parameters, reference images, and their associated image files into a ZIP-based `.invk` file. On load, it restores the full state, handling image deduplication and re-uploading as needed. + +## File Format + +The `.invk` file is a standard ZIP archive with the following structure: + +``` +project.invk +├── manifest.json +├── canvas_state.json +├── params.json +├── ref_images.json +├── loras.json +└── images/ + ├── {image_name_1}.png + ├── {image_name_2}.png + └── ... +``` + +### manifest.json + +Schema version and metadata. Validated on load with Zod. + +```json +{ + "version": 1, + "appVersion": "5.12.0", + "createdAt": "2026-02-26T12:00:00.000Z", + "name": "My Canvas Project" +} +``` + +| Field | Type | Description | +|---|---|---| +| `version` | `number` | Schema version, currently `1`. Used for migration logic on load. | +| `appVersion` | `string` | InvokeAI version that created the file. Informational only. | +| `createdAt` | `string` | ISO 8601 timestamp. | +| `name` | `string` | User-provided project name. Also used as the download filename. | + +### canvas_state.json + +The serialized canvas entity tree. Type: `CanvasProjectState`. + +```typescript +type CanvasProjectState = { + rasterLayers: CanvasRasterLayerState[]; + controlLayers: CanvasControlLayerState[]; + inpaintMasks: CanvasInpaintMaskState[]; + regionalGuidance: CanvasRegionalGuidanceState[]; + bbox: CanvasState['bbox']; + selectedEntityIdentifier: CanvasState['selectedEntityIdentifier']; + bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier']; +}; +``` + +Each entity contains its full state including all canvas objects (brush lines, eraser lines, rect shapes, images). Image objects reference files by `image_name` which correspond to files in the `images/` folder. + +### params.json + +The complete generation parameters state (`ParamsState`). Optional on load (older files may not have it). This includes all fields from the params Redux slice: + +- Prompts (positive, negative, prompt history) +- Core generation settings (seed, steps, CFG scale, guidance, scheduler, iterations) +- Model selections (main model, VAE, FLUX VAE, T5 encoder, CLIP embed models, refiner, Z-Image models, Klein models) +- Dimensions (width, height, aspect ratio) +- Img2img strength +- Infill settings (method, tile size, patchmatch downscale, color) +- Canvas coherence settings (mode, edge size, min denoise) +- Refiner parameters (steps, CFG scale, scheduler, aesthetic scores, start) +- FLUX-specific settings (scheduler, DyPE preset/scale/exponent) +- Z-Image-specific settings (scheduler, seed variance) +- Upscale settings (scheduler, CFG scale) +- Seamless tiling, mask blur, CLIP skip, VAE precision, CPU noise, color compensation + +### ref_images.json + +Global reference image entities (`RefImageState[]`). These are IP-Adapter / FLUX Redux configs with `CroppableImageWithDims` containing both original and cropped image references. Optional on load. + +### loras.json + +Array of LoRA configurations (`LoRA[]`). Each entry contains: + +```typescript +type LoRA = { + id: string; + isEnabled: boolean; + model: ModelIdentifierField; + weight: number; +}; +``` + +Optional on load. Like models, LoRA identifiers are stored as-is — if a LoRA is not installed when loading, the entry is restored but may not be usable. + +### images/ + +All image files referenced anywhere in the state. Keyed by their original `image_name`. On save, each image is fetched from the backend via `GET /api/v1/images/i/{name}/full` and stored as-is. + +## Key Source Files + +| File | Purpose | +|---|---| +| `features/controlLayers/util/canvasProjectFile.ts` | Types, constants, image name collection, remapping, existence checking | +| `features/controlLayers/hooks/useCanvasProjectSave.ts` | Save hook — collects Redux state, fetches images, builds ZIP | +| `features/controlLayers/hooks/useCanvasProjectLoad.ts` | Load hook — parses ZIP, deduplicates images, dispatches state | +| `features/controlLayers/components/SaveCanvasProjectDialog.tsx` | Save name dialog + `useSaveCanvasProjectWithDialog` hook | +| `features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx` | Load confirmation dialog + `useLoadCanvasProjectWithDialog` hook | +| `features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx` | Toolbar dropdown UI | +| `features/controlLayers/store/canvasSlice.ts` | `canvasProjectRecalled` Redux action | + +## Save Flow + +1. User clicks "Save Canvas Project" → `SaveCanvasProjectDialog` opens asking for a project name +2. On confirm, `saveCanvasProject(name)` is called +3. Read Redux state via selectors: `selectCanvasSlice()`, `selectParamsSlice()`, `selectRefImagesSlice()`, `selectLoRAsSlice()` +4. Build `CanvasProjectState` from the canvas slice; use `paramsState` directly for params +5. Walk all entities to collect every `image_name` reference via `collectImageNames()`: + - `CanvasImageState.image.image_name` in layer/mask objects + - `CroppableImageWithDims.original.image.image_name` in global ref images + - `CroppableImageWithDims.crop.image.image_name` in cropped ref images + - `ImageWithDims.image_name` in regional guidance ref images +6. Fetch each image from the backend API +7. Build ZIP with JSZip: add `manifest.json` (including `name`), `canvas_state.json`, `params.json`, `ref_images.json`, and all images into `images/` +8. Sanitize the name for filesystem use and generate blob, trigger download as `{name}.invk` + +## Load Flow + +1. User selects `.invk` file → confirmation dialog opens +2. On confirm, parse ZIP with JSZip +3. Validate manifest version via Zod schema +4. Read `canvas_state.json`, `params.json` (optional), `ref_images.json` (optional) +5. Collect all `image_name` references from the loaded state +6. **Deduplicate images**: for each referenced image, check if it exists on the server via `getImageDTOSafe(image_name)` + - Already exists → skip (no upload) + - Missing → upload from ZIP via `uploadImage()`, record `oldName → newName` mapping +7. Remap all `image_name` values in the loaded state using the mapping (only for re-uploaded images whose names changed) +8. Dispatch Redux actions: + - `canvasProjectRecalled()` — restores all canvas entities, bbox, selected/bookmarked entity + - `refImagesRecalled()` — restores global reference images + - `paramsRecalled()` — replaces the entire params state in one action + - `loraAllDeleted()` + `loraRecalled()` — restores LoRAs +9. Show success/error toast + +## Image Name Collection & Remapping + +The `canvasProjectFile.ts` utility provides two parallel sets of functions: + +**Collection** (`collectImageNames`): Walks the entire state tree and returns a `Set` of all referenced `image_name` values. This is used by both save (to know which images to fetch) and load (to know which images to check/upload). + +**Remapping** (`remapCanvasState`, `remapRefImages`): Deep-clones state objects and replaces `image_name` values using a `Map` mapping. Only images that were re-uploaded with a different name are remapped. Images that already existed on the server are left unchanged. + +Both walk the same paths through the state tree: +- Layer/mask objects → `CanvasImageState.image.image_name` +- Regional guidance ref images → `ImageWithDims.image_name` +- Global ref images → `CroppableImageWithDims.original.image.image_name` and `.crop.image.image_name` + +## Extending the Format + +### Adding new optional data (non-breaking) + +Add a new JSON file to the ZIP. No version bump needed. + +1. **Save**: Add `zip.file('new_data.json', JSON.stringify(data))` in `useCanvasProjectSave.ts` +2. **Load**: Read with `zip.file('new_data.json')` in `useCanvasProjectLoad.ts` — check for `null` so older project files without it still load +3. **Dispatch**: Add the appropriate Redux action to restore the data + +### Adding new entity types with images + +1. Extend `CanvasProjectState` type in `canvasProjectFile.ts` +2. Add collection logic in `collectImageNames()` to walk the new entity's objects +3. Add remapping logic in `remapCanvasState()` to update image names +4. Include the new entity array in both save and load hooks +5. Handle it in the `canvasProjectRecalled` reducer in `canvasSlice.ts` + +### Breaking schema changes + +1. Bump `CANVAS_PROJECT_VERSION` in `canvasProjectFile.ts` +2. Update the Zod manifest schema: `version: z.union([z.literal(1), z.literal(2)])` +3. Add migration logic in the load hook: check version, transform v1 → v2 before dispatching + +## UI Architecture + +### Save dialog + +The save flow uses a **nanostore atom** (`$isOpen`) to control the `SaveCanvasProjectDialog`: + +1. `useSaveCanvasProjectWithDialog()` — returns a callback that sets `$isOpen` to `true` +2. `SaveCanvasProjectDialog` (singleton in `GlobalModalIsolator`) — renders an `AlertDialog` with a name input +3. On save → calls `saveCanvasProject(name)` and closes the dialog +4. On cancel → closes the dialog + +### Load dialog + +The load flow uses a **nanostore atom** (`$pendingFile`) to decouple the file dialog from the confirmation dialog: + +1. `useLoadCanvasProjectWithDialog()` — opens a programmatic file input (`document.createElement('input')`) +2. On file selection → sets `$pendingFile` atom +3. `LoadCanvasProjectConfirmationAlertDialog` (singleton in `GlobalModalIsolator`) — subscribes to `$pendingFile` via `useStore()` +4. On accept → calls `loadCanvasProject(file)` and clears the atom +5. On cancel → clears the atom + +The programmatic file input approach was chosen because the context menu component uses `isLazy: true`, which unmounts the DOM tree when the menu closes — a hidden `` element inside the menu would be destroyed before the file dialog returns. diff --git a/docs/features/canvas_projects.md b/docs/features/canvas_projects.md new file mode 100644 index 0000000000..8b161c6745 --- /dev/null +++ b/docs/features/canvas_projects.md @@ -0,0 +1,56 @@ +--- +title: Canvas Projects +--- + +# :material-folder-zip: Canvas Projects + +## Save and Restore Your Canvas Work + +Canvas Projects let you save your entire canvas setup to a file and load it back later. This is useful when you want to: + +- **Switch between tasks** without losing your current canvas arrangement +- **Back up complex setups** with multiple layers, masks, and reference images +- **Share canvas layouts** with others or transfer them between machines +- **Recover from deleted images** — all images are embedded in the project file + +## What Gets Saved + +A canvas project file (`.invk`) captures everything about your current canvas session: + +- **All layers** — raster layers, control layers, inpaint masks, regional guidance +- **All drawn content** — brush strokes, pasted images, eraser marks +- **Reference images** — global IP-Adapter / FLUX Redux images with crop settings +- **Regional guidance** — per-region prompts and reference images +- **Bounding box** — position, size, aspect ratio, and scale settings +- **All generation parameters** — prompts, seed, steps, CFG scale, guidance, scheduler, model, VAE, dimensions, img2img strength, infill settings, canvas coherence, refiner settings, FLUX/Z-Image specific parameters, and more +- **LoRAs** — all added LoRA models with their weights and enabled/disabled state + +## How to Save a Project + +You can save from two places: + +1. **Toolbar** — Click the **Archive icon** in the canvas toolbar, then select **Save Canvas Project** +2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Save Canvas Project** + +A dialog will ask you to enter a **project name**. This name is used as the filename (e.g., entering "My Portrait" saves as `My Portrait.invk`) and is stored inside the project file. + +## How to Load a Project + +1. **Toolbar** — Click the **Archive icon**, then select **Load Canvas Project** +2. **Context menu** — Right-click the canvas, open the **Project** submenu, then select **Load Canvas Project** + +A file dialog will open. Select your `.invk` file. You will see a confirmation dialog warning that loading will replace your current canvas. Click **Load** to proceed. + +### What Happens on Load + +- Your current canvas is **completely replaced** — all existing layers, masks, reference images, and parameters are overwritten +- Images that are already present on your InvokeAI server are reused automatically (no duplicate uploads) +- Images that were deleted from the server are re-uploaded from the project file +- If the saved model is not installed on your system, the model identifier is still restored — you will need to select an available model manually + +## Good to Know + +- **No undo** — Loading a project replaces your canvas entirely. There is no way to undo this action, so save your current project first if you want to keep it. +- **Image deduplication** — When loading, images already on your server are not re-uploaded. Only missing images are uploaded from the project file. +- **File size** — The `.invk` file size depends on the number and resolution of images in your canvas. A project with many high-resolution layers can be large. +- **Model availability** — The project saves which model was selected, but does not include the model itself. If the model is not installed when you load the project, you will need to select a different one. diff --git a/invokeai/frontend/web/package.json b/invokeai/frontend/web/package.json index e9a896f1b4..e537362801 100644 --- a/invokeai/frontend/web/package.json +++ b/invokeai/frontend/web/package.json @@ -66,6 +66,7 @@ "i18next-http-backend": "^3.0.2", "idb-keyval": "6.2.1", "jsondiffpatch": "^0.7.3", + "jszip": "^3.10.1", "konva": "^9.3.22", "linkify-react": "^4.3.1", "linkifyjs": "^4.3.1", diff --git a/invokeai/frontend/web/pnpm-lock.yaml b/invokeai/frontend/web/pnpm-lock.yaml index 3f94ba7d69..6a2ed95ab0 100644 --- a/invokeai/frontend/web/pnpm-lock.yaml +++ b/invokeai/frontend/web/pnpm-lock.yaml @@ -86,6 +86,9 @@ importers: jsondiffpatch: specifier: ^0.7.3 version: 0.7.3 + jszip: + specifier: ^3.10.1 + version: 3.10.1 konva: specifier: ^9.3.22 version: 9.3.22 @@ -2003,6 +2006,9 @@ packages: copy-to-clipboard@3.3.3: resolution: {integrity: sha512-2KV8NhB5JqC3ky0r9PMCAZKbUHSwtEo4CwCs0KXgruG43gX5PMqDEBbVU4OUzw2MuAWUfsuFmWvEKG5QRfSnJA==} + core-util-is@1.0.3: + resolution: {integrity: sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==} + cosmiconfig@7.1.0: resolution: {integrity: sha512-AdmX6xUzdNASswsFtmwSt7Vj8po9IuqXm0UXz7QKPuEUmPB4XyjGfaAr2PSuELMwkRMVH1EpIkX5bTZGRB3eCA==} engines: {node: '>=10'} @@ -2672,6 +2678,9 @@ packages: resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} engines: {node: '>= 4'} + immediate@3.0.6: + resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==} + immer@10.1.1: resolution: {integrity: sha512-s2MPrmjovJcoMaHtx6K11Ra7oD05NT97w1IC5zpMkT6Atjr7H8LjaDd81iIxUYpMKSRRNMJE703M1Fhr/TctHw==} @@ -2825,6 +2834,9 @@ packages: resolution: {integrity: sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==} engines: {node: '>=8'} + isarray@1.0.0: + resolution: {integrity: sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==} + isarray@2.0.5: resolution: {integrity: sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==} @@ -2916,6 +2928,9 @@ packages: resolution: {integrity: sha512-ZZow9HBI5O6EPgSJLUb8n2NKgmVWTwCvHGwFuJlMjvLFqlGG6pjirPhtdsseaLZjSibD8eegzmYpUZwoIlj2cQ==} engines: {node: '>=4.0'} + jszip@3.10.1: + resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==} + keyv@4.5.4: resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==} @@ -2934,6 +2949,9 @@ packages: resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==} engines: {node: '>= 0.8.0'} + lie@3.3.0: + resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==} + lines-and-columns@1.2.4: resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} @@ -3210,6 +3228,9 @@ packages: package-json-from-dist@1.0.1: resolution: {integrity: sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==} + pako@1.0.11: + resolution: {integrity: sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==} + pako@2.1.0: resolution: {integrity: sha512-w+eufiZ1WuJYgPXbV/PO3NCMEc3xqylkKHzp8bxp1uW4qaSNQUkwmLLEc3kKsfz8lpV1F8Ht3U1Cm+9Srog2ug==} @@ -3298,6 +3319,9 @@ packages: resolution: {integrity: sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==} engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0} + process-nextick-args@2.0.1: + resolution: {integrity: sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==} + prop-types@15.8.1: resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} @@ -3539,6 +3563,9 @@ packages: resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} engines: {node: '>=0.10.0'} + readable-stream@2.3.8: + resolution: {integrity: sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==} + readable-stream@3.6.2: resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==} engines: {node: '>= 6'} @@ -3661,6 +3688,9 @@ packages: resolution: {integrity: sha512-AURm5f0jYEOydBj7VQlVvDrjeFgthDdEF5H1dP+6mNpoXOMo1quQqJ4wvJDyRZ9+pO3kGWoOdmV08cSv2aJV6Q==} engines: {node: '>=0.4'} + safe-buffer@5.1.2: + resolution: {integrity: sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==} + safe-buffer@5.2.1: resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==} @@ -3718,6 +3748,9 @@ packages: resolution: {integrity: sha512-RJRdvCo6IAnPdsvP/7m6bsQqNnn1FCBX5ZNtFL98MmFF/4xAIJTIg1YbHW5DC2W5SKZanrC6i4HsJqlajw/dZw==} engines: {node: '>= 0.4'} + setimmediate@1.0.5: + resolution: {integrity: sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==} + shebang-command@2.0.0: resolution: {integrity: sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==} engines: {node: '>=8'} @@ -3857,6 +3890,9 @@ packages: resolution: {integrity: sha512-UXSH262CSZY1tfu3G3Secr6uGLCFVPMhIqHjlgCUtCCcgihYc/xKs9djMTMUOb2j1mVSeU8EU6NWc/iQKU6Gfg==} engines: {node: '>= 0.4'} + string_decoder@1.1.1: + resolution: {integrity: sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==} + string_decoder@1.3.0: resolution: {integrity: sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==} @@ -6153,6 +6189,8 @@ snapshots: dependencies: toggle-selection: 1.0.6 + core-util-is@1.0.3: {} + cosmiconfig@7.1.0: dependencies: '@types/parse-json': 4.0.2 @@ -6957,6 +6995,8 @@ snapshots: ignore@7.0.5: {} + immediate@3.0.6: {} + immer@10.1.1: {} import-fresh@3.3.1: @@ -7103,6 +7143,8 @@ snapshots: dependencies: is-docker: 2.2.1 + isarray@1.0.0: {} + isarray@2.0.5: {} isexe@2.0.0: {} @@ -7192,6 +7234,13 @@ snapshots: object.assign: 4.1.7 object.values: 1.2.1 + jszip@3.10.1: + dependencies: + lie: 3.3.0 + pako: 1.0.11 + readable-stream: 2.3.8 + setimmediate: 1.0.5 + keyv@4.5.4: dependencies: json-buffer: 3.0.1 @@ -7221,6 +7270,10 @@ snapshots: prelude-ls: 1.2.1 type-check: 0.4.0 + lie@3.3.0: + dependencies: + immediate: 3.0.6 + lines-and-columns@1.2.4: {} linkify-react@4.3.1(linkifyjs@4.3.1)(react@18.3.1): @@ -7510,6 +7563,8 @@ snapshots: package-json-from-dist@1.0.1: {} + pako@1.0.11: {} + pako@2.1.0: {} parent-module@1.0.1: @@ -7578,6 +7633,8 @@ snapshots: ansi-styles: 5.2.0 react-is: 17.0.2 + process-nextick-args@2.0.1: {} + prop-types@15.8.1: dependencies: loose-envify: 1.4.0 @@ -7843,6 +7900,16 @@ snapshots: dependencies: loose-envify: 1.4.0 + readable-stream@2.3.8: + dependencies: + core-util-is: 1.0.3 + inherits: 2.0.4 + isarray: 1.0.0 + process-nextick-args: 2.0.1 + safe-buffer: 5.1.2 + string_decoder: 1.1.1 + util-deprecate: 1.0.2 + readable-stream@3.6.2: dependencies: inherits: 2.0.4 @@ -7994,6 +8061,8 @@ snapshots: has-symbols: 1.1.0 isarray: 2.0.5 + safe-buffer@5.1.2: {} + safe-buffer@5.2.1: {} safe-push-apply@1.0.0: @@ -8051,6 +8120,8 @@ snapshots: es-errors: 1.3.0 es-object-atoms: 1.1.1 + setimmediate@1.0.5: {} + shebang-command@2.0.0: dependencies: shebang-regex: 3.0.0 @@ -8236,6 +8307,10 @@ snapshots: define-properties: 1.2.1 es-object-atoms: 1.1.1 + string_decoder@1.1.1: + dependencies: + safe-buffer: 5.1.2 + string_decoder@1.3.0: dependencies: safe-buffer: 5.2.1 diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 6d79635522..fd6154760e 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -3019,6 +3019,19 @@ "copyCanvasToClipboard": "Copy Canvas to Clipboard", "copyBboxToClipboard": "Copy Bbox to Clipboard" }, + "canvasProject": { + "project": "Project", + "saveProject": "Save Canvas Project", + "loadProject": "Load Canvas Project", + "saveSuccess": "Project Saved", + "saveSuccessDesc": "Saved project with {{count}} images", + "saveError": "Failed to Save Project", + "loadSuccess": "Project Loaded", + "loadSuccessDesc": "Canvas state restored from project file", + "loadError": "Failed to Load Project", + "loadWarning": "Loading a project will replace your current canvas, including all layers, masks, reference images, and generation parameters. This action cannot be undone.", + "projectName": "Project Name" + }, "stagingArea": { "accept": "Accept", "discardAll": "Discard All", diff --git a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx index ef0747707f..e5ec5ccc56 100644 --- a/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx +++ b/invokeai/frontend/web/src/app/components/GlobalModalIsolator.tsx @@ -2,6 +2,8 @@ import { GlobalImageHotkeys } from 'app/components/GlobalImageHotkeys'; import ChangeBoardModal from 'features/changeBoardModal/components/ChangeBoardModal'; import { CanvasPasteModal } from 'features/controlLayers/components/CanvasPasteModal'; import { CanvasWorkflowIntegrationModal } from 'features/controlLayers/components/CanvasWorkflowIntegration/CanvasWorkflowIntegrationModal'; +import { LoadCanvasProjectConfirmationAlertDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog'; +import { SaveCanvasProjectDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog'; import { CanvasManagerProviderGate } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import { CropImageModal } from 'features/cropper/components/CropImageModal'; import { DeleteImageModal } from 'features/deleteImageModal/components/DeleteImageModal'; @@ -54,6 +56,8 @@ export const GlobalModalIsolator = memo(() => { + + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx index ca264fa389..064378b227 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasContextMenu/CanvasContextMenuGlobalMenuItems.tsx @@ -2,6 +2,8 @@ import { Menu, MenuButton, MenuGroup, MenuItem, MenuList } from '@invoke-ai/ui-l import { SubMenuButtonContent, useSubMenu } from 'common/hooks/useSubMenu'; import { CanvasContextMenuItemsCropCanvasToBbox } from 'features/controlLayers/components/CanvasContextMenu/CanvasContextMenuItemsCropCanvasToBbox'; import { NewLayerIcon } from 'features/controlLayers/components/common/icons'; +import { useLoadCanvasProjectWithDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog'; +import { useSaveCanvasProjectWithDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog'; import { useCopyCanvasToClipboard } from 'features/controlLayers/hooks/copyHooks'; import { useNewControlLayerFromBbox, @@ -14,16 +16,19 @@ import { import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiCopyBold, PiFloppyDiskBold } from 'react-icons/pi'; +import { PiArchiveBold, PiCopyBold, PiFileArrowDownBold, PiFileArrowUpBold, PiFloppyDiskBold } from 'react-icons/pi'; export const CanvasContextMenuGlobalMenuItems = memo(() => { const { t } = useTranslation(); const saveSubMenu = useSubMenu(); + const projectSubMenu = useSubMenu(); const newSubMenu = useSubMenu(); const copySubMenu = useSubMenu(); const isBusy = useCanvasIsBusy(); const saveCanvasToGallery = useSaveCanvasToGallery(); const saveBboxToGallery = useSaveBboxToGallery(); + const saveCanvasProject = useSaveCanvasProjectWithDialog(); + const loadCanvasProject = useLoadCanvasProjectWithDialog(); const newRegionalReferenceImageFromBbox = useNewRegionalReferenceImageFromBbox(); const newGlobalReferenceImageFromBbox = useNewGlobalReferenceImageFromBbox(); const newRasterLayerFromBbox = useNewRasterLayerFromBbox(); @@ -50,6 +55,21 @@ export const CanvasContextMenuGlobalMenuItems = memo(() => {
+ }> + + + + + + } isDisabled={isBusy} onClick={saveCanvasProject}> + {t('controlLayers.canvasProject.saveProject')} + + } isDisabled={isBusy} onClick={loadCanvasProject}> + {t('controlLayers.canvasProject.loadProject')} + + + + }> diff --git a/invokeai/frontend/web/src/features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx b/invokeai/frontend/web/src/features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx new file mode 100644 index 0000000000..149d5b4f17 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog.tsx @@ -0,0 +1,69 @@ +import { ConfirmationAlertDialog, Flex, Text } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; +import { useCanvasProjectLoad } from 'features/controlLayers/hooks/useCanvasProjectLoad'; +import { CANVAS_PROJECT_EXTENSION } from 'features/controlLayers/util/canvasProjectFile'; +import { atom } from 'nanostores'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; + +const $pendingFile = atom(null); + +const openFileDialog = (onFileSelected: (file: File) => void) => { + const input = document.createElement('input'); + input.type = 'file'; + input.accept = CANVAS_PROJECT_EXTENSION; + input.onchange = () => { + const file = input.files?.[0]; + if (file) { + onFileSelected(file); + } + }; + input.click(); +}; + +export const useLoadCanvasProjectWithDialog = () => { + const openDialog = useCallback(() => { + openFileDialog((file) => { + $pendingFile.set(file); + }); + }, []); + + return openDialog; +}; + +export const LoadCanvasProjectConfirmationAlertDialog = memo(() => { + useAssertSingleton('LoadCanvasProjectConfirmationAlertDialog'); + const { t } = useTranslation(); + const { loadCanvasProject } = useCanvasProjectLoad(); + const pendingFile = useStore($pendingFile); + + const onClose = useCallback(() => { + $pendingFile.set(null); + }, []); + + const onAccept = useCallback(() => { + const file = $pendingFile.get(); + if (file) { + void loadCanvasProject(file); + } + $pendingFile.set(null); + }, [loadCanvasProject]); + + return ( + + + {t('controlLayers.canvasProject.loadWarning')} + + + ); +}); + +LoadCanvasProjectConfirmationAlertDialog.displayName = 'LoadCanvasProjectConfirmationAlertDialog'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/SaveCanvasProjectDialog.tsx b/invokeai/frontend/web/src/features/controlLayers/components/SaveCanvasProjectDialog.tsx new file mode 100644 index 0000000000..bf947ba7c4 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/SaveCanvasProjectDialog.tsx @@ -0,0 +1,92 @@ +import { + AlertDialog, + AlertDialogBody, + AlertDialogContent, + AlertDialogFooter, + AlertDialogHeader, + Button, + Flex, + FormControl, + FormLabel, + Input, +} from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; +import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; +import { useCanvasProjectSave } from 'features/controlLayers/hooks/useCanvasProjectSave'; +import { atom } from 'nanostores'; +import type { ChangeEvent, RefObject } from 'react'; +import { memo, useCallback, useRef, useState } from 'react'; +import { useTranslation } from 'react-i18next'; + +const $isOpen = atom(false); + +export const useSaveCanvasProjectWithDialog = () => { + return useCallback(() => { + $isOpen.set(true); + }, []); +}; + +export const SaveCanvasProjectDialog = memo(() => { + useAssertSingleton('SaveCanvasProjectDialog'); + const isOpen = useStore($isOpen); + const cancelRef = useRef(null); + + const onClose = useCallback(() => { + $isOpen.set(false); + }, []); + + return ( + + {isOpen && } + + ); +}); + +SaveCanvasProjectDialog.displayName = 'SaveCanvasProjectDialog'; + +const Content = memo(({ cancelRef }: { cancelRef: RefObject }) => { + const { t } = useTranslation(); + const { saveCanvasProject } = useCanvasProjectSave(); + const [name, setName] = useState('Canvas Project'); + + const onChange = useCallback((e: ChangeEvent) => { + setName(e.target.value); + }, []); + + const onClose = useCallback(() => { + $isOpen.set(false); + }, []); + + const onSave = useCallback(() => { + void saveCanvasProject(name); + $isOpen.set(false); + }, [name, saveCanvasProject]); + + return ( + + + {t('controlLayers.canvasProject.saveProject')} + + + + + {t('controlLayers.canvasProject.projectName')} + + + + + + + + + + + + ); +}); + +Content.displayName = 'SaveCanvasProjectDialogContent'; diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx index faea5d98c3..fc34f4331c 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbar.tsx @@ -11,6 +11,7 @@ import { ToolWidthPicker } from 'features/controlLayers/components/Tool/ToolWidt import { CanvasToolbarFitBboxToLayersButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToLayersButton'; import { CanvasToolbarFitBboxToMasksButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarFitBboxToMasksButton'; import { CanvasToolbarNewSessionMenuButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarNewSessionMenuButton'; +import { CanvasToolbarProjectMenuButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton'; import { CanvasToolbarRedoButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarRedoButton'; import { CanvasToolbarResetViewButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarResetViewButton'; import { CanvasToolbarSaveToGalleryButton } from 'features/controlLayers/components/Toolbar/CanvasToolbarSaveToGalleryButton'; @@ -74,6 +75,7 @@ export const CanvasToolbar = memo(() => { + diff --git a/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx new file mode 100644 index 0000000000..92cdc629ac --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/Toolbar/CanvasToolbarProjectMenuButton.tsx @@ -0,0 +1,37 @@ +import { IconButton, Menu, MenuButton, MenuItem, MenuList } from '@invoke-ai/ui-library'; +import { useLoadCanvasProjectWithDialog } from 'features/controlLayers/components/LoadCanvasProjectConfirmationAlertDialog'; +import { useSaveCanvasProjectWithDialog } from 'features/controlLayers/components/SaveCanvasProjectDialog'; +import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy'; +import { memo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiArchiveBold, PiFileArrowDownBold, PiFileArrowUpBold } from 'react-icons/pi'; + +export const CanvasToolbarProjectMenuButton = memo(() => { + const { t } = useTranslation(); + const isBusy = useCanvasIsBusy(); + const saveCanvasProject = useSaveCanvasProjectWithDialog(); + const loadCanvasProject = useLoadCanvasProjectWithDialog(); + + return ( + + } + variant="link" + alignSelf="stretch" + /> + + } isDisabled={isBusy} onClick={saveCanvasProject}> + {t('controlLayers.canvasProject.saveProject')} + + } isDisabled={isBusy} onClick={loadCanvasProject}> + {t('controlLayers.canvasProject.loadProject')} + + + + ); +}); + +CanvasToolbarProjectMenuButton.displayName = 'CanvasToolbarProjectMenuButton'; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectLoad.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectLoad.ts new file mode 100644 index 0000000000..21de5d3b22 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectLoad.ts @@ -0,0 +1,157 @@ +import { logger } from 'app/logging/logger'; +import { useAppDispatch } from 'app/store/storeHooks'; +import { parseify } from 'common/util/serialize'; +import { canvasProjectRecalled } from 'features/controlLayers/store/canvasSlice'; +import { loraAllDeleted, loraRecalled } from 'features/controlLayers/store/lorasSlice'; +import { paramsRecalled } from 'features/controlLayers/store/paramsSlice'; +import { refImagesRecalled } from 'features/controlLayers/store/refImagesSlice'; +import type { LoRA, ParamsState, RefImageState } from 'features/controlLayers/store/types'; +import type { CanvasProjectState } from 'features/controlLayers/util/canvasProjectFile'; +import { + checkExistingImages, + collectImageNames, + parseManifest, + processWithConcurrencyLimit, + remapCanvasState, + remapRefImages, +} from 'features/controlLayers/util/canvasProjectFile'; +import { toast } from 'features/toast/toast'; +import JSZip from 'jszip'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { uploadImage } from 'services/api/endpoints/images'; + +const log = logger('canvas'); + +export const useCanvasProjectLoad = () => { + const { t } = useTranslation(); + const dispatch = useAppDispatch(); + + const loadCanvasProject = useCallback( + async (file: File) => { + try { + const zip = await JSZip.loadAsync(file); + + // Validate manifest + const manifestFile = zip.file('manifest.json'); + if (!manifestFile) { + throw new Error('Invalid project file: missing manifest.json'); + } + const manifestData = JSON.parse(await manifestFile.async('string')); + parseManifest(manifestData); + + // Read state files + const canvasStateFile = zip.file('canvas_state.json'); + if (!canvasStateFile) { + throw new Error('Invalid project file: missing canvas_state.json'); + } + const canvasState: CanvasProjectState = JSON.parse(await canvasStateFile.async('string')); + + const paramsFile = zip.file('params.json'); + let projectParams: ParamsState | null = null; + if (paramsFile) { + projectParams = JSON.parse(await paramsFile.async('string')); + } + + const refImagesFile = zip.file('ref_images.json'); + let refImages: RefImageState[] = []; + if (refImagesFile) { + refImages = JSON.parse(await refImagesFile.async('string')); + } + + const lorasFile = zip.file('loras.json'); + let loras: LoRA[] = []; + if (lorasFile) { + loras = JSON.parse(await lorasFile.async('string')); + } + + // Collect all image names referenced in the state + const imageNames = collectImageNames(canvasState, refImages); + + // Check which images already exist on the server + const { missing } = await checkExistingImages(imageNames); + + // Upload missing images from the ZIP + const imageNameMapping = new Map(); + const imagesFolder = zip.folder('images'); + + if (imagesFolder && missing.size > 0) { + await processWithConcurrencyLimit(Array.from(missing), async (imageName) => { + const imageFile = imagesFolder.file(imageName); + if (!imageFile) { + log.warn(`Image ${imageName} referenced but not found in ZIP`); + return; + } + + try { + const blob = await imageFile.async('blob'); + const uploadFile = new File([blob], imageName, { type: 'image/png' }); + const imageDTO = await uploadImage({ + file: uploadFile, + image_category: 'general', + is_intermediate: false, + silent: true, + }); + + // Map old name to new name (only if different) + if (imageDTO.image_name !== imageName) { + imageNameMapping.set(imageName, imageDTO.image_name); + } + } catch (error) { + log.warn({ error: parseify(error) }, `Failed to upload image ${imageName}`); + } + }); + } + + // Remap image names in state objects + const remappedCanvasState = remapCanvasState(canvasState, imageNameMapping); + const remappedRefImages = remapRefImages(refImages, imageNameMapping); + + // Dispatch state restoration + dispatch( + canvasProjectRecalled({ + rasterLayers: remappedCanvasState.rasterLayers, + controlLayers: remappedCanvasState.controlLayers, + inpaintMasks: remappedCanvasState.inpaintMasks, + regionalGuidance: remappedCanvasState.regionalGuidance, + bbox: remappedCanvasState.bbox, + selectedEntityIdentifier: remappedCanvasState.selectedEntityIdentifier, + bookmarkedEntityIdentifier: remappedCanvasState.bookmarkedEntityIdentifier, + }) + ); + + // Restore reference images + dispatch(refImagesRecalled({ entities: remappedRefImages, replace: true })); + + // Restore generation parameters + if (projectParams) { + dispatch(paramsRecalled(projectParams)); + } + + // Restore LoRAs (always clear, even if project has none) + dispatch(loraAllDeleted()); + for (const lora of loras) { + dispatch(loraRecalled({ lora })); + } + + toast({ + id: 'CANVAS_PROJECT_LOAD_SUCCESS', + title: t('controlLayers.canvasProject.loadSuccess'), + description: t('controlLayers.canvasProject.loadSuccessDesc'), + status: 'success', + }); + } catch (error) { + log.error({ error: parseify(error) }, 'Failed to load canvas project'); + toast({ + id: 'CANVAS_PROJECT_LOAD_ERROR', + title: t('controlLayers.canvasProject.loadError'), + description: String(error), + status: 'error', + }); + } + }, + [dispatch, t] + ); + + return { loadCanvasProject }; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectSave.ts b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectSave.ts new file mode 100644 index 0000000000..76a91a2efa --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/hooks/useCanvasProjectSave.ts @@ -0,0 +1,116 @@ +import { logger } from 'app/logging/logger'; +import { useAppStore } from 'app/store/storeHooks'; +import { parseify } from 'common/util/serialize'; +import { downloadBlob } from 'features/controlLayers/konva/util'; +import { selectLoRAsSlice } from 'features/controlLayers/store/lorasSlice'; +import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; +import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; +import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; +import type { CanvasProjectManifest, CanvasProjectState } from 'features/controlLayers/util/canvasProjectFile'; +import { + CANVAS_PROJECT_EXTENSION, + CANVAS_PROJECT_VERSION, + collectImageNames, + processWithConcurrencyLimit, +} from 'features/controlLayers/util/canvasProjectFile'; +import { toast } from 'features/toast/toast'; +import JSZip from 'jszip'; +import { useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { useGetAppVersionQuery } from 'services/api/endpoints/appInfo'; + +const log = logger('canvas'); + +const sanitizeFileName = (name: string): string => { + // Replace characters that are invalid in filenames + return name.replace(/[<>:"/\\|?*]/g, '_').trim() || 'canvas-project'; +}; + +export const useCanvasProjectSave = () => { + const { t } = useTranslation(); + const store = useAppStore(); + const { data: appVersion } = useGetAppVersionQuery(); + + const saveCanvasProject = useCallback( + async (name: string) => { + try { + const state = store.getState(); + const canvasState = selectCanvasSlice(state); + const paramsState = selectParamsSlice(state); + const refImagesState = selectRefImagesSlice(state); + const lorasState = selectLoRAsSlice(state); + + // Build the canvas project state + const projectState: CanvasProjectState = { + rasterLayers: canvasState.rasterLayers.entities, + controlLayers: canvasState.controlLayers.entities, + inpaintMasks: canvasState.inpaintMasks.entities, + regionalGuidance: canvasState.regionalGuidance.entities, + bbox: canvasState.bbox, + selectedEntityIdentifier: canvasState.selectedEntityIdentifier, + bookmarkedEntityIdentifier: canvasState.bookmarkedEntityIdentifier, + }; + + // Collect all image names referenced in the state + const imageNames = collectImageNames(projectState, refImagesState.entities); + + // Build ZIP + const zip = new JSZip(); + + // Add manifest + const manifest: CanvasProjectManifest = { + version: CANVAS_PROJECT_VERSION, + appVersion: appVersion?.version ?? 'unknown', + createdAt: new Date().toISOString(), + name, + }; + zip.file('manifest.json', JSON.stringify(manifest, null, 2)); + + // Add state files + zip.file('canvas_state.json', JSON.stringify(projectState, null, 2)); + zip.file('params.json', JSON.stringify(paramsState, null, 2)); + zip.file('ref_images.json', JSON.stringify(refImagesState.entities, null, 2)); + zip.file('loras.json', JSON.stringify(lorasState.loras, null, 2)); + + // Fetch and add images + const imagesFolder = zip.folder('images')!; + await processWithConcurrencyLimit(Array.from(imageNames), async (imageName) => { + try { + const response = await fetch(`/api/v1/images/i/${imageName}/full`); + if (!response.ok) { + log.warn(`Failed to fetch image ${imageName}: ${response.status}`); + return; + } + const blob = await response.blob(); + imagesFolder.file(imageName, blob); + } catch (error) { + log.warn({ error: parseify(error) }, `Failed to fetch image ${imageName}`); + } + }); + + // Generate ZIP blob and trigger download + const blob = await zip.generateAsync({ type: 'blob' }); + const fileName = `${sanitizeFileName(name)}${CANVAS_PROJECT_EXTENSION}`; + downloadBlob(blob, fileName); + + toast({ + id: 'CANVAS_PROJECT_SAVE_SUCCESS', + title: t('controlLayers.canvasProject.saveSuccess'), + description: t('controlLayers.canvasProject.saveSuccessDesc', { count: imageNames.size }), + status: 'success', + }); + } catch (error) { + log.error({ error: parseify(error) }, 'Failed to save canvas project'); + toast({ + id: 'CANVAS_PROJECT_SAVE_ERROR', + title: t('controlLayers.canvasProject.saveError'), + description: String(error), + status: 'error', + }); + } + }, + [appVersion?.version, store, t] + ); + + return { saveCanvasProject }; +}; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts index d114568c3c..fd170e19e8 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasSlice.ts @@ -1736,6 +1736,36 @@ const slice = createSlice({ resetInpaintMasksHiddenIfEmpty(state); return state; }, + canvasProjectRecalled: ( + state, + action: PayloadAction<{ + rasterLayers: CanvasRasterLayerState[]; + controlLayers: CanvasControlLayerState[]; + inpaintMasks: CanvasInpaintMaskState[]; + regionalGuidance: CanvasRegionalGuidanceState[]; + bbox: CanvasState['bbox']; + selectedEntityIdentifier: CanvasState['selectedEntityIdentifier']; + bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier']; + }> + ) => { + const { + rasterLayers, + controlLayers, + inpaintMasks, + regionalGuidance, + bbox, + selectedEntityIdentifier, + bookmarkedEntityIdentifier, + } = action.payload; + state.rasterLayers.entities = rasterLayers; + state.controlLayers.entities = controlLayers; + state.inpaintMasks.entities = inpaintMasks; + state.regionalGuidance.entities = regionalGuidance; + state.bbox = bbox; + state.selectedEntityIdentifier = selectedEntityIdentifier; + state.bookmarkedEntityIdentifier = bookmarkedEntityIdentifier; + return state; + }, canvasUndo: () => {}, canvasRedo: () => {}, canvasClearHistory: () => {}, @@ -1794,6 +1824,7 @@ const resetState = (state: CanvasState) => { export const { canvasMetadataRecalled, + canvasProjectRecalled, canvasUndo, canvasRedo, canvasClearHistory, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts index 09a16f4bca..ba78c36c3f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/paramsSlice.ts @@ -471,6 +471,9 @@ const slice = createSlice({ } }, paramsReset: (state) => resetState(state), + paramsRecalled: (_state, action: PayloadAction) => { + return action.payload; + }, }, extraReducers(builder) { // Reset params state on logout to prevent user data leakage when switching users @@ -609,6 +612,7 @@ export const { syncedToOptimalDimension, paramsReset, + paramsRecalled, animaVaeModelSelected, animaQwen3EncoderModelSelected, animaT5EncoderModelSelected, diff --git a/invokeai/frontend/web/src/features/controlLayers/util/canvasProjectFile.ts b/invokeai/frontend/web/src/features/controlLayers/util/canvasProjectFile.ts new file mode 100644 index 0000000000..97ea31e8bb --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/util/canvasProjectFile.ts @@ -0,0 +1,287 @@ +import { deepClone } from 'common/util/deepClone'; +import type { + CanvasControlLayerState, + CanvasInpaintMaskState, + CanvasObjectState, + CanvasRasterLayerState, + CanvasRegionalGuidanceState, + CanvasState, + CroppableImageWithDims, + ImageWithDims, + RefImageState, +} from 'features/controlLayers/store/types'; +import { getImageDTOSafe } from 'services/api/endpoints/images'; +import { z } from 'zod'; + +export const CANVAS_PROJECT_VERSION = 1; +export const CANVAS_PROJECT_EXTENSION = '.invk'; + +// #region Manifest + +const zCanvasProjectManifest = z.object({ + version: z.literal(CANVAS_PROJECT_VERSION), + appVersion: z.string(), + createdAt: z.string(), + name: z.string(), +}); +export type CanvasProjectManifest = z.infer; + +export const parseManifest = (data: unknown): CanvasProjectManifest => { + return zCanvasProjectManifest.parse(data); +}; + +// #endregion + +// #region Canvas Project State + +export type CanvasProjectState = { + rasterLayers: CanvasRasterLayerState[]; + controlLayers: CanvasControlLayerState[]; + inpaintMasks: CanvasInpaintMaskState[]; + regionalGuidance: CanvasRegionalGuidanceState[]; + bbox: CanvasState['bbox']; + selectedEntityIdentifier: CanvasState['selectedEntityIdentifier']; + bookmarkedEntityIdentifier: CanvasState['bookmarkedEntityIdentifier']; +}; + +// #endregion + +// #region Image Name Collection + +/** + * Collects image_name values from a CroppableImageWithDims (used by ref images). + */ +const collectFromCroppableImage = (image: CroppableImageWithDims | null, names: Set): void => { + if (!image) { + return; + } + names.add(image.original.image.image_name); + if (image.crop?.image) { + names.add(image.crop.image.image_name); + } +}; + +/** + * Collects image_name values from an ImageWithDims (used by regional guidance ref images). + */ +const collectFromImageWithDims = (image: ImageWithDims | null, names: Set): void => { + if (!image) { + return; + } + names.add(image.image_name); +}; + +/** + * Collects image_name values from canvas objects (brush lines, images, etc.). + */ +const collectFromObjects = (objects: CanvasObjectState[], names: Set): void => { + for (const obj of objects) { + if (obj.type === 'image' && 'image_name' in obj.image) { + names.add(obj.image.image_name); + } + } +}; + +/** + * Walks the entire canvas state + ref images and returns a deduplicated set of all image_name references. + */ +export const collectImageNames = (canvasState: CanvasProjectState, refImages: RefImageState[]): Set => { + const names = new Set(); + + // Raster layers + for (const layer of canvasState.rasterLayers) { + collectFromObjects(layer.objects, names); + } + + // Control layers + for (const layer of canvasState.controlLayers) { + collectFromObjects(layer.objects, names); + } + + // Inpaint masks + for (const mask of canvasState.inpaintMasks) { + collectFromObjects(mask.objects, names); + } + + // Regional guidance + for (const rg of canvasState.regionalGuidance) { + collectFromObjects(rg.objects, names); + for (const refImage of rg.referenceImages) { + if (refImage.config.type === 'ip_adapter' || refImage.config.type === 'flux_redux') { + collectFromImageWithDims(refImage.config.image, names); + } + } + } + + // Global reference images + for (const refImage of refImages) { + collectFromCroppableImage(refImage.config.image, names); + } + + return names; +}; + +// #endregion + +// #region Image Name Remapping + +/** + * Remaps image_name values in a CroppableImageWithDims. + */ +/** + * Remaps image_name values in a CroppableImageWithDims in-place. + * Caller is responsible for cloning beforehand. + */ +const remapCroppableImage = (image: CroppableImageWithDims | null, mapping: Map): void => { + if (!image) { + return; + } + + const newOriginalName = mapping.get(image.original.image.image_name); + if (newOriginalName) { + image.original.image.image_name = newOriginalName; + } + + if (image.crop?.image) { + const newCropName = mapping.get(image.crop.image.image_name); + if (newCropName) { + image.crop.image.image_name = newCropName; + } + } +}; + +/** + * Remaps image_name in an ImageWithDims. + */ +const remapImageWithDims = (image: ImageWithDims | null, mapping: Map): ImageWithDims | null => { + if (!image) { + return null; + } + + const result = deepClone(image); + const newName = mapping.get(result.image_name); + if (newName) { + result.image_name = newName; + } + return result; +}; + +/** + * Remaps image_name values in canvas objects. + */ +const remapObjects = (objects: CanvasObjectState[], mapping: Map): CanvasObjectState[] => { + return objects.map((obj) => { + if (obj.type === 'image' && 'image_name' in obj.image) { + const newName = mapping.get(obj.image.image_name); + if (newName) { + return { ...obj, image: { ...obj.image, image_name: newName } }; + } + } + return obj; + }); +}; + +/** + * Deep-clones canvas state and remaps all image_name values using the provided mapping. + * Only images present in the mapping are changed (images that already existed on the server are skipped). + */ +export const remapCanvasState = (canvasState: CanvasProjectState, mapping: Map): CanvasProjectState => { + if (mapping.size === 0) { + return canvasState; + } + + const result = deepClone(canvasState); + + for (const layer of result.rasterLayers) { + layer.objects = remapObjects(layer.objects, mapping); + } + + for (const layer of result.controlLayers) { + layer.objects = remapObjects(layer.objects, mapping); + } + + for (const mask of result.inpaintMasks) { + mask.objects = remapObjects(mask.objects, mapping); + } + + for (const rg of result.regionalGuidance) { + rg.objects = remapObjects(rg.objects, mapping); + for (const refImage of rg.referenceImages) { + if (refImage.config.type === 'ip_adapter' || refImage.config.type === 'flux_redux') { + refImage.config.image = remapImageWithDims(refImage.config.image, mapping); + } + } + } + + return result; +}; + +/** + * Deep-clones ref images and remaps all image_name values using the provided mapping. + */ +export const remapRefImages = (refImages: RefImageState[], mapping: Map): RefImageState[] => { + if (mapping.size === 0) { + return refImages; + } + + return refImages.map((refImage) => { + const result = deepClone(refImage); + remapCroppableImage(result.config.image, mapping); + return result; + }); +}; + +// #endregion + +// #region Concurrency + +const MAX_CONCURRENT_REQUESTS = 5; + +/** + * Processes an array of async tasks with a concurrency limit. + */ +export const processWithConcurrencyLimit = async ( + items: T[], + fn: (item: T) => Promise, + limit: number = MAX_CONCURRENT_REQUESTS +): Promise => { + let index = 0; + + const next = async (): Promise => { + while (index < items.length) { + const currentIndex = index++; + await fn(items[currentIndex]!); + } + }; + + const workers = Array.from({ length: Math.min(limit, items.length) }, () => next()); + await Promise.all(workers); +}; + +// #endregion + +// #region Image Existence Check + +/** + * Checks which images already exist on the backend server. + * Returns sets of existing and missing image names. + */ +export const checkExistingImages = async ( + imageNames: Set +): Promise<{ existing: Set; missing: Set }> => { + const existing = new Set(); + const missing = new Set(); + + await processWithConcurrencyLimit(Array.from(imageNames), async (imageName) => { + const dto = await getImageDTOSafe(imageName); + if (dto) { + existing.add(imageName); + } else { + missing.add(imageName); + } + }); + + return { existing, missing }; +}; + +// #endregion From 1b50c1a79c450b7bbe997e217c1b96eb080d2c88 Mon Sep 17 00:00:00 2001 From: Valeri Che <38873282+DustyShoe@users.noreply.github.com> Date: Tue, 14 Apr 2026 04:10:35 +0300 Subject: [PATCH 6/6] Feat(UI): Replace prompt window resize handle with bottom edge drag handle. (#8975) * feat(ui): replace prompt window resize handle with bottom-edge drag handle * Fix: removed unused export --------- Co-authored-by: Josh Corbett --- .../components/Core/ParamNegativePrompt.tsx | 8 +- .../components/Core/ParamPositivePrompt.tsx | 9 +- .../components/Prompts/PromptResizeHandle.tsx | 135 ++++++++++++++++++ 3 files changed, 149 insertions(+), 3 deletions(-) create mode 100644 invokeai/frontend/web/src/features/parameters/components/Prompts/PromptResizeHandle.tsx diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx index 685a0b7a2a..b0405c0ff3 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx @@ -4,6 +4,7 @@ import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize' import { negativePromptChanged, selectNegativePromptWithFallback } from 'features/controlLayers/store/paramsSlice'; import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel'; import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper'; +import { PromptResizeHandle } from 'features/parameters/components/Prompts/PromptResizeHandle'; import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt'; import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; import { PromptPopover } from 'features/prompt/PromptPopover'; @@ -22,6 +23,8 @@ const persistOptions: Parameters[2] = { trackHeight: true, }; +const NEGATIVE_PROMPT_MIN_HEIGHT = 28; + export const ParamNegativePrompt = memo(() => { const dispatch = useAppDispatch(); const prompt = useAppSelector(selectNegativePromptWithFallback); @@ -70,14 +73,16 @@ export const ParamNegativePrompt = memo(() => { onChange={onChange} onKeyDown={onKeyDown} variant="darkFilled" - minH={28} borderTopWidth={24} // This prevents the prompt from being hidden behind the header paddingInlineEnd={10} paddingInlineStart={3} paddingTop={0} paddingBottom={3} + resize="none" + minH={NEGATIVE_PROMPT_MIN_HEIGHT} fontFamily="mono" fontSize="0.82rem" + sx={{ '&::-webkit-resizer': { display: 'none' } }} /> @@ -90,6 +95,7 @@ export const ParamNegativePrompt = memo(() => { label={`${t('parameters.negativePromptPlaceholder')} (${t('stylePresets.preview')})`} /> )} + ); diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index f95e950c25..89169b5ea5 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -11,6 +11,7 @@ import { ShowDynamicPromptsPreviewButton } from 'features/dynamicPrompts/compone import { NegativePromptToggleButton } from 'features/parameters/components/Core/NegativePromptToggleButton'; import { PromptLabel } from 'features/parameters/components/Prompts/PromptLabel'; import { PromptOverlayButtonWrapper } from 'features/parameters/components/Prompts/PromptOverlayButtonWrapper'; +import { PromptResizeHandle } from 'features/parameters/components/Prompts/PromptResizeHandle'; import { ViewModePrompt } from 'features/parameters/components/Prompts/ViewModePrompt'; import { AddPromptTriggerButton } from 'features/prompt/AddPromptTriggerButton'; import { PromptPopover } from 'features/prompt/PromptPopover'; @@ -35,6 +36,8 @@ const persistOptions: Parameters[2] = { initialHeight: 120, }; +const POSITIVE_PROMPT_MIN_HEIGHT = 32; + const usePromptHistory = () => { const store = useAppStore(); const history = useAppSelector(selectPositivePromptHistory); @@ -215,10 +218,11 @@ export const ParamPositivePrompt = memo(() => { paddingInlineStart={3} paddingTop={0} paddingBottom={3} - resize="vertical" - minH={32} + resize="none" + minH={POSITIVE_PROMPT_MIN_HEIGHT} fontFamily="mono" fontSize="0.82rem" + sx={{ '&::-webkit-resizer': { display: 'none' } }} /> @@ -236,6 +240,7 @@ export const ParamPositivePrompt = memo(() => { label={`${t('parameters.positivePromptPlaceholder')} (${t('stylePresets.preview')})`} /> )} + diff --git a/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptResizeHandle.tsx b/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptResizeHandle.tsx new file mode 100644 index 0000000000..0a5f211924 --- /dev/null +++ b/invokeai/frontend/web/src/features/parameters/components/Prompts/PromptResizeHandle.tsx @@ -0,0 +1,135 @@ +import { Box } from '@invoke-ai/ui-library'; +import { + memo, + type PointerEvent as ReactPointerEvent, + type RefObject, + useCallback, + useEffect, + useRef, + useState, +} from 'react'; + +type PromptResizeHandleProps = { + textareaRef: RefObject; + minHeight: number; +}; + +const PROMPT_RESIZE_HANDLE_HEIGHT_PX = 8; + +export const PromptResizeHandle = memo(({ textareaRef, minHeight }: PromptResizeHandleProps) => { + const activePointerIdRef = useRef(null); + const startHeightRef = useRef(0); + const startYRef = useRef(0); + const previousCursorRef = useRef(''); + const previousUserSelectRef = useRef(''); + const [isResizing, setIsResizing] = useState(false); + + const stopResize = useCallback(() => { + if (activePointerIdRef.current === null) { + return; + } + + activePointerIdRef.current = null; + setIsResizing(false); + document.body.style.cursor = previousCursorRef.current; + document.body.style.userSelect = previousUserSelectRef.current; + }, []); + + useEffect(() => stopResize, [stopResize]); + + const onPointerDown = useCallback( + (e: ReactPointerEvent) => { + if (e.button !== 0) { + return; + } + + const textarea = textareaRef.current; + if (!textarea) { + return; + } + + activePointerIdRef.current = e.pointerId; + startYRef.current = e.clientY; + startHeightRef.current = textarea.offsetHeight; + previousCursorRef.current = document.body.style.cursor; + previousUserSelectRef.current = document.body.style.userSelect; + + document.body.style.cursor = 'ns-resize'; + document.body.style.userSelect = 'none'; + e.currentTarget.setPointerCapture(e.pointerId); + setIsResizing(true); + e.preventDefault(); + }, + [textareaRef] + ); + + const onPointerMove = useCallback( + (e: ReactPointerEvent) => { + if (activePointerIdRef.current !== e.pointerId) { + return; + } + + const textarea = textareaRef.current; + if (!textarea) { + return; + } + + const nextHeight = Math.max(minHeight, startHeightRef.current + e.clientY - startYRef.current); + textarea.style.height = `${nextHeight}px`; + e.preventDefault(); + }, + [minHeight, textareaRef] + ); + + const onPointerUp = useCallback( + (e: ReactPointerEvent) => { + if (activePointerIdRef.current !== e.pointerId) { + return; + } + + if (e.currentTarget.hasPointerCapture(e.pointerId)) { + e.currentTarget.releasePointerCapture(e.pointerId); + } + + stopResize(); + }, + [stopResize] + ); + + const onPointerCancel = useCallback( + (e: ReactPointerEvent) => { + if (activePointerIdRef.current !== e.pointerId) { + return; + } + + stopResize(); + }, + [stopResize] + ); + + return ( + + ); +}); + +PromptResizeHandle.displayName = 'PromptResizeHandle';