From 5d4610d981ee352b266f2300a8f47f19cbc77cd1 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 1 Jan 2024 12:27:05 +1100 Subject: [PATCH] feat(ui): store node templates in separate slice Flattens the `nodes` slice. May offer minor perf improvements in addition to just being cleaner. --- .../middleware/devtools/actionSanitizer.ts | 2 +- .../listeners/receivedOpenAPISchema.ts | 2 +- .../listeners/socketio/socketConnected.ts | 4 +- .../listeners/updateAllNodesRequested.ts | 2 +- .../listeners/workflowLoadRequested.ts | 2 +- invokeai/frontend/web/src/app/store/store.ts | 2 + .../src/common/hooks/useIsReadyToEnqueue.ts | 11 ++- .../flow/AddNodePopover/AddNodePopover.tsx | 91 ++++++++++--------- .../Invocation/InvocationNodeWrapper.tsx | 4 +- .../inspector/InspectorDetailsTab.tsx | 39 ++++---- .../inspector/InspectorOutputsTab.tsx | 49 +++++----- .../inspector/InspectorTemplateTab.tsx | 29 +++--- .../hooks/useAnyOrDirectInputFieldNames.ts | 4 +- .../src/features/nodes/hooks/useBuildNode.ts | 11 +-- .../hooks/useConnectionInputFieldNames.ts | 4 +- .../nodes/hooks/useDoNodeVersionsMatch.ts | 4 +- .../features/nodes/hooks/useFieldInputKind.ts | 4 +- .../nodes/hooks/useFieldInputTemplate.ts | 4 +- .../nodes/hooks/useFieldOutputTemplate.ts | 4 +- .../features/nodes/hooks/useFieldTemplate.ts | 4 +- .../nodes/hooks/useFieldTemplateTitle.ts | 4 +- .../nodes/hooks/useGetNodesNeedUpdate.ts | 14 +-- .../nodes/hooks/useNodeClassification.ts | 4 +- .../nodes/hooks/useNodeNeedsUpdate.ts | 4 +- .../features/nodes/hooks/useNodeTemplate.ts | 4 +- .../nodes/hooks/useNodeTemplateByType.ts | 5 +- .../nodes/hooks/useNodeTemplateTitle.ts | 4 +- .../nodes/hooks/useOutputFieldNames.ts | 4 +- .../nodes/store/nodeTemplatesSlice.ts | 26 ++++++ .../nodes/store/nodesPersistDenylist.ts | 1 - .../src/features/nodes/store/nodesSlice.ts | 14 +-- .../web/src/features/nodes/store/types.ts | 5 +- .../nodes/util/workflow/migrations.ts | 2 +- 33 files changed, 200 insertions(+), 167 deletions(-) create mode 100644 invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts index da5fb224d2..108c925ad9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts @@ -1,6 +1,6 @@ import type { UnknownAction } from '@reduxjs/toolkit'; import { isAnyGraphBuilt } from 'features/nodes/store/actions'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; import type { Graph } from 'services/api/types'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts index 05a509b155..a11c49d069 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/receivedOpenAPISchema.ts @@ -1,6 +1,6 @@ import { logger } from 'app/logging/logger'; import { parseify } from 'common/util/serialize'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { size } from 'lodash-es'; import { receivedOpenAPISchema } from 'services/api/thunks/schema'; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts index 4039cf2406..88b9b3b5d9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/socketio/socketConnected.ts @@ -15,11 +15,11 @@ export const addSocketConnectedEventListener = () => { log.debug('Connected'); - const { nodes, config, system } = getState(); + const { nodeTemplates, config, system } = getState(); const { disabledTabs } = config; - if (!size(nodes.nodeTemplates) && !disabledTabs.includes('nodes')) { + if (!size(nodeTemplates.templates) && !disabledTabs.includes('nodes')) { dispatch(receivedOpenAPISchema()); } diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index 11bad0c221..371983c781 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -19,7 +19,7 @@ export const addUpdateAllNodesRequestedListener = () => { effect: (action, { dispatch, getState }) => { const log = logger('nodes'); const nodes = getState().nodes.nodes; - const templates = getState().nodes.nodeTemplates; + const templates = getState().nodeTemplates.templates; let unableToUpdateCount = 0; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index 0a0803fc07..1e26ea7a0c 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -25,7 +25,7 @@ export const addWorkflowLoadRequestedListener = () => { effect: (action, { dispatch, getState }) => { const log = logger('nodes'); const { workflow, asCopy } = action.payload; - const nodeTemplates = getState().nodes.nodeTemplates; + const nodeTemplates = getState().nodeTemplates.templates; try { const { workflow: validatedWorkflow, warnings } = validateWorkflow( diff --git a/invokeai/frontend/web/src/app/store/store.ts b/invokeai/frontend/web/src/app/store/store.ts index 56e879ced3..e7540e25b4 100644 --- a/invokeai/frontend/web/src/app/store/store.ts +++ b/invokeai/frontend/web/src/app/store/store.ts @@ -14,6 +14,7 @@ import hrfReducer from 'features/hrf/store/hrfSlice'; import loraReducer from 'features/lora/store/loraSlice'; import modelmanagerReducer from 'features/modelManager/store/modelManagerSlice'; import nodesReducer from 'features/nodes/store/nodesSlice'; +import nodeTemplatesReducer from 'features/nodes/store/nodeTemplatesSlice'; import workflowReducer from 'features/nodes/store/workflowSlice'; import generationReducer from 'features/parameters/store/generationSlice'; import postprocessingReducer from 'features/parameters/store/postprocessingSlice'; @@ -42,6 +43,7 @@ const allReducers = { gallery: galleryReducer, generation: generationReducer, nodes: nodesReducer, + nodeTemplates: nodeTemplatesReducer, postprocessing: postprocessingReducer, system: systemReducer, config: configReducer, diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 4de778eac9..35d90be2f0 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -12,7 +12,14 @@ import { getConnectedEdges } from 'reactflow'; const selector = createMemoizedSelector( [stateSelector, activeTabNameSelector], ( - { controlAdapters, generation, system, nodes, dynamicPrompts }, + { + controlAdapters, + generation, + system, + nodes, + nodeTemplates, + dynamicPrompts, + }, activeTabName ) => { const { initialImage, model } = generation; @@ -41,7 +48,7 @@ const selector = createMemoizedSelector( return; } - const nodeTemplate = nodes.nodeTemplates[node.data.type]; + const nodeTemplate = nodeTemplates.templates[node.data.type]; if (!nodeTemplate) { // Node type not found diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 5139dc1ff8..8512fb1934 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -74,57 +74,60 @@ const AddNodePopover = () => { (state) => state.nodes.connectionStartParams?.handleType ); - const selector = createMemoizedSelector([stateSelector], ({ nodes }) => { - // If we have a connection in progress, we need to filter the node choices - const filteredNodeTemplates = fieldFilter - ? filter(nodes.nodeTemplates, (template) => { - const handles = - handleFilter == 'source' ? template.inputs : template.outputs; + const selector = createMemoizedSelector( + [stateSelector], + ({ nodeTemplates }) => { + // If we have a connection in progress, we need to filter the node choices + const filteredNodeTemplates = fieldFilter + ? filter(nodeTemplates.templates, (template) => { + const handles = + handleFilter == 'source' ? template.inputs : template.outputs; - return some(handles, (handle) => { - const sourceType = - handleFilter == 'source' ? fieldFilter : handle.type; - const targetType = - handleFilter == 'target' ? fieldFilter : handle.type; + return some(handles, (handle) => { + const sourceType = + handleFilter == 'source' ? fieldFilter : handle.type; + const targetType = + handleFilter == 'target' ? fieldFilter : handle.type; - return validateSourceAndTargetTypes(sourceType, targetType); - }); - }) - : map(nodes.nodeTemplates); + return validateSourceAndTargetTypes(sourceType, targetType); + }); + }) + : map(nodeTemplates.templates); - const options: InvSelectOption[] = map( - filteredNodeTemplates, - (template) => { - return { - label: template.title, - value: template.type, - description: template.description, - tags: template.tags, - }; + const options: InvSelectOption[] = map( + filteredNodeTemplates, + (template) => { + return { + label: template.title, + value: template.type, + description: template.description, + tags: template.tags, + }; + } + ); + + //We only want these nodes if we're not filtered + if (fieldFilter === null) { + options.push({ + label: t('nodes.currentImage'), + value: 'current_image', + description: t('nodes.currentImageDescription'), + tags: ['progress'], + }); + + options.push({ + label: t('nodes.notes'), + value: 'notes', + description: t('nodes.notesDescription'), + tags: ['notes'], + }); } - ); - //We only want these nodes if we're not filtered - if (fieldFilter === null) { - options.push({ - label: t('nodes.currentImage'), - value: 'current_image', - description: t('nodes.currentImageDescription'), - tags: ['progress'], - }); + options.sort((a, b) => a.label.localeCompare(b.label)); - options.push({ - label: t('nodes.notes'), - value: 'notes', - description: t('nodes.notesDescription'), - tags: ['notes'], - }); + return { options }; } - - options.sort((a, b) => a.label.localeCompare(b.label)); - - return { options }; - }); + ); const { options } = useAppSelector(selector); const isOpen = useAppSelector((state) => state.nodes.isAddNodePopoverOpen); diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx index 0132cf209c..e80a6a31bc 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx @@ -14,8 +14,8 @@ const InvocationNodeWrapper = (props: NodeProps) => { const hasTemplateSelector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => - Boolean(nodes.nodeTemplates[type]) + createMemoizedSelector(stateSelector, ({ nodeTemplates }) => + Boolean(nodeTemplates.templates[type]) ), [type] ); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index dcb0a17f43..6f0b99c401 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -14,28 +14,31 @@ import { useTranslation } from 'react-i18next'; import EditableNodeTitle from './details/EditableNodeTitle'; -const selector = createMemoizedSelector(stateSelector, ({ nodes }) => { - const lastSelectedNodeId = - nodes.selectedNodes[nodes.selectedNodes.length - 1]; +const selector = createMemoizedSelector( + stateSelector, + ({ nodes, nodeTemplates }) => { + const lastSelectedNodeId = + nodes.selectedNodes[nodes.selectedNodes.length - 1]; - const lastSelectedNode = nodes.nodes.find( - (node) => node.id === lastSelectedNodeId - ); + const lastSelectedNode = nodes.nodes.find( + (node) => node.id === lastSelectedNodeId + ); - const lastSelectedNodeTemplate = lastSelectedNode - ? nodes.nodeTemplates[lastSelectedNode.data.type] - : undefined; + const lastSelectedNodeTemplate = lastSelectedNode + ? nodeTemplates.templates[lastSelectedNode.data.type] + : undefined; - if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { - return; + if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { + return; + } + + return { + nodeId: lastSelectedNode.data.id, + nodeVersion: lastSelectedNode.data.version, + templateTitle: lastSelectedNodeTemplate.title, + }; } - - return { - nodeId: lastSelectedNode.data.id, - nodeVersion: lastSelectedNode.data.version, - templateTitle: lastSelectedNodeTemplate.title, - }; -}); +); const InspectorDetailsTab = () => { const data = useAppSelector(selector); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index a868c328e7..1e53aaa6b1 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -13,34 +13,37 @@ import type { AnyResult } from 'services/events/types'; import ImageOutputPreview from './outputs/ImageOutputPreview'; -const selector = createMemoizedSelector(stateSelector, ({ nodes }) => { - const lastSelectedNodeId = - nodes.selectedNodes[nodes.selectedNodes.length - 1]; +const selector = createMemoizedSelector( + stateSelector, + ({ nodes, nodeTemplates }) => { + const lastSelectedNodeId = + nodes.selectedNodes[nodes.selectedNodes.length - 1]; - const lastSelectedNode = nodes.nodes.find( - (node) => node.id === lastSelectedNodeId - ); + const lastSelectedNode = nodes.nodes.find( + (node) => node.id === lastSelectedNodeId + ); - const lastSelectedNodeTemplate = lastSelectedNode - ? nodes.nodeTemplates[lastSelectedNode.data.type] - : undefined; + const lastSelectedNodeTemplate = lastSelectedNode + ? nodeTemplates.templates[lastSelectedNode.data.type] + : undefined; - const nes = - nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; + const nes = + nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; - if ( - !isInvocationNode(lastSelectedNode) || - !nes || - !lastSelectedNodeTemplate - ) { - return; + if ( + !isInvocationNode(lastSelectedNode) || + !nes || + !lastSelectedNodeTemplate + ) { + return; + } + + return { + outputs: nes.outputs, + outputType: lastSelectedNodeTemplate.outputType, + }; } - - return { - outputs: nes.outputs, - outputType: lastSelectedNodeTemplate.outputType, - }; -}); +); const InspectorOutputsTab = () => { const data = useAppSelector(selector); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx index 8ea02c6340..ef2ad34908 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx @@ -6,22 +6,25 @@ import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataView import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createMemoizedSelector(stateSelector, ({ nodes }) => { - const lastSelectedNodeId = - nodes.selectedNodes[nodes.selectedNodes.length - 1]; +const selector = createMemoizedSelector( + stateSelector, + ({ nodes, nodeTemplates }) => { + const lastSelectedNodeId = + nodes.selectedNodes[nodes.selectedNodes.length - 1]; - const lastSelectedNode = nodes.nodes.find( - (node) => node.id === lastSelectedNodeId - ); + const lastSelectedNode = nodes.nodes.find( + (node) => node.id === lastSelectedNodeId + ); - const lastSelectedNodeTemplate = lastSelectedNode - ? nodes.nodeTemplates[lastSelectedNode.data.type] - : undefined; + const lastSelectedNodeTemplate = lastSelectedNode + ? nodeTemplates.templates[lastSelectedNode.data.type] + : undefined; - return { - template: lastSelectedNodeTemplate, - }; -}); + return { + template: lastSelectedNodeTemplate, + }; + } +); const NodeTemplateInspector = () => { const { template } = useAppSelector(selector); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts index e72f310859..c5a3909e42 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useAnyOrDirectInputFieldNames.ts @@ -10,12 +10,12 @@ import { useMemo } from 'react'; export const useAnyOrDirectInputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return []; } - const nodeTemplate = nodes.nodeTemplates[node.data.type]; + const nodeTemplate = nodeTemplates.templates[node.data.type]; if (!nodeTemplate) { return []; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index ea44887d6b..88d723fff9 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -1,5 +1,3 @@ -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; -import type { RootState } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import { DRAG_HANDLE_CLASSNAME, @@ -16,17 +14,14 @@ import { useCallback } from 'react'; import type { Node } from 'reactflow'; import { useReactFlow } from 'reactflow'; -const templatesSelector = createMemoizedSelector( - [(state: RootState) => state.nodes], - (nodes) => nodes.nodeTemplates -); - export const SHARED_NODE_PROPERTIES: Partial = { dragHandle: `.${DRAG_HANDLE_CLASSNAME}`, }; export const useBuildNode = () => { - const nodeTemplates = useAppSelector(templatesSelector); + const nodeTemplates = useAppSelector( + (state) => state.nodeTemplates.templates + ); const flow = useReactFlow(); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts index bda47fd420..9091062193 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useConnectionInputFieldNames.ts @@ -10,12 +10,12 @@ import { useMemo } from 'react'; export const useConnectionInputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return []; } - const nodeTemplate = nodes.nodeTemplates[node.data.type]; + const nodeTemplate = nodeTemplates.templates[node.data.type]; if (!nodeTemplate) { return []; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts index b79f34945b..d996926fa6 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useDoNodeVersionsMatch.ts @@ -8,12 +8,12 @@ import { useMemo } from 'react'; export const useDoNodeVersionsMatch = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return false; } - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; if (!nodeTemplate?.version || !node.data?.version) { return false; } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts index c2d4b85023..0b20559f07 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputKind.ts @@ -7,12 +7,12 @@ import { useMemo } from 'react'; export const useFieldInputKind = (nodeId: string, fieldName: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return; } - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; const fieldTemplate = nodeTemplate?.inputs[fieldName]; return fieldTemplate?.input; }), diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts index 852fa5f9c4..5a381be1a8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldInputTemplate.ts @@ -7,12 +7,12 @@ import { useMemo } from 'react'; export const useFieldInputTemplate = (nodeId: string, fieldName: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return; } - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; return nodeTemplate?.inputs[fieldName]; }), [fieldName, nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts index 852c8ab353..4ad16b4308 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldOutputTemplate.ts @@ -7,12 +7,12 @@ import { useMemo } from 'react'; export const useFieldOutputTemplate = (nodeId: string, fieldName: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return; } - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; return nodeTemplate?.outputs[fieldName]; }), [fieldName, nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts index e5c632a0cd..d018706f67 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplate.ts @@ -12,12 +12,12 @@ export const useFieldTemplate = ( ) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return; } - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; return nodeTemplate?.[KIND_MAP[kind]][fieldName]; }), [fieldName, kind, nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts index 57d52bc0ef..05897a7823 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useFieldTemplateTitle.ts @@ -12,12 +12,12 @@ export const useFieldTemplateTitle = ( ) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return; } - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; return nodeTemplate?.[KIND_MAP[kind]][fieldName]?.title; }), [fieldName, kind, nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index 06e8b03da8..175aa31209 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -4,19 +4,15 @@ import { useAppSelector } from 'app/store/storeHooks'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; -const selector = createMemoizedSelector(stateSelector, (state) => { - const nodes = state.nodes.nodes; - const templates = state.nodes.nodeTemplates; - - const needsUpdate = nodes.filter(isInvocationNode).some((node) => { - const template = templates[node.data.type]; +const selector = createMemoizedSelector(stateSelector, (state) => + state.nodes.nodes.filter(isInvocationNode).some((node) => { + const template = state.nodeTemplates.templates[node.data.type]; if (!template) { return false; } return getNeedsUpdate(node, template); - }); - return needsUpdate; -}); + }) +); export const useGetNodesNeedUpdate = () => { const getNeedsUpdate = useAppSelector(selector); diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts index 773f6de249..1af0c0ce1e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeClassification.ts @@ -7,12 +7,12 @@ import { useMemo } from 'react'; export const useNodeClassification = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return false; } - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; return nodeTemplate?.classification; }), [nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts index da69150411..ae115e6e5c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeNeedsUpdate.ts @@ -8,9 +8,9 @@ import { useMemo } from 'react'; export const useNodeNeedsUpdate = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); - const template = nodes.nodeTemplates[node?.data.type ?? '']; + const template = nodeTemplates.templates[node?.data.type ?? '']; if (isInvocationNode(node) && template) { return getNeedsUpdate(node, template); } diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts index c8bdca3cf6..c5ecfe0ed1 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplate.ts @@ -6,9 +6,9 @@ import { useMemo } from 'react'; export const useNodeTemplate = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); - const nodeTemplate = nodes.nodeTemplates[node?.data.type ?? '']; + const nodeTemplate = nodeTemplates.templates[node?.data.type ?? '']; return nodeTemplate; }), [nodeId] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts index fe47ab34a2..eff72add80 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateByType.ts @@ -9,9 +9,8 @@ export const useNodeTemplateByType = (type: string) => { () => createMemoizedSelector( stateSelector, - ({ nodes }): InvocationTemplate | undefined => { - const nodeTemplate = nodes.nodeTemplates[type]; - return nodeTemplate; + ({ nodeTemplates }): InvocationTemplate | undefined => { + return nodeTemplates.templates[type]; } ), [type] diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts index 37ba2243c4..7aada521c4 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useNodeTemplateTitle.ts @@ -7,13 +7,13 @@ import { useMemo } from 'react'; export const useNodeTemplateTitle = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return false; } const nodeTemplate = node - ? nodes.nodeTemplates[node.data.type] + ? nodeTemplates.templates[node.data.type] : undefined; return nodeTemplate?.title; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts index 1dc4d8862c..fe57c11e1e 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useOutputFieldNames.ts @@ -9,12 +9,12 @@ import { useMemo } from 'react'; export const useOutputFieldNames = (nodeId: string) => { const selector = useMemo( () => - createMemoizedSelector(stateSelector, ({ nodes }) => { + createMemoizedSelector(stateSelector, ({ nodes, nodeTemplates }) => { const node = nodes.nodes.find((node) => node.id === nodeId); if (!isInvocationNode(node)) { return []; } - const nodeTemplate = nodes.nodeTemplates[node.data.type]; + const nodeTemplate = nodeTemplates.templates[node.data.type]; if (!nodeTemplate) { return []; } diff --git a/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts new file mode 100644 index 0000000000..8080dd8eca --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/store/nodeTemplatesSlice.ts @@ -0,0 +1,26 @@ +import type { PayloadAction } from '@reduxjs/toolkit'; +import { createSlice } from '@reduxjs/toolkit'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; + +import type { NodeTemplatesState } from './types'; + +export const initialNodeTemplatesState: NodeTemplatesState = { + templates: {}, +}; + +const nodesTemplatesSlice = createSlice({ + name: 'nodeTemplates', + initialState: initialNodeTemplatesState, + reducers: { + nodeTemplatesBuilt: ( + state, + action: PayloadAction> + ) => { + state.templates = action.payload; + }, + }, +}); + +export const { nodeTemplatesBuilt } = nodesTemplatesSlice.actions; + +export default nodesTemplatesSlice.reducer; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts index 1a151afe92..5e3947314a 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesPersistDenylist.ts @@ -4,7 +4,6 @@ import type { NodesState } from './types'; * Nodes slice persist denylist */ export const nodesPersistDenylist: (keyof NodesState)[] = [ - 'nodeTemplates', 'connectionStartParams', 'connectionStartFieldType', 'selectedNodes', diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index fdbdaf01ef..30c8ac5691 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -1,6 +1,7 @@ import type { PayloadAction } from '@reduxjs/toolkit'; import { createSlice, isAnyOf } from '@reduxjs/toolkit'; import { workflowLoaded } from 'features/nodes/store/actions'; +import { nodeTemplatesBuilt } from 'features/nodes/store/nodeTemplatesSlice'; import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants'; import type { BoardFieldValue, @@ -41,7 +42,6 @@ import { } from 'features/nodes/types/field'; import type { AnyNode, - InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; import { @@ -97,7 +97,6 @@ const initialNodeExecutionState: Omit = { export const initialNodesState: NodesState = { nodes: [], edges: [], - nodeTemplates: {}, isReady: false, connectionStartParams: null, connectionStartFieldType: null, @@ -656,13 +655,6 @@ const nodesSlice = createSlice({ shouldShowMinimapPanelChanged: (state, action: PayloadAction) => { state.shouldShowMinimapPanel = action.payload; }, - nodeTemplatesBuilt: ( - state, - action: PayloadAction> - ) => { - state.nodeTemplates = action.payload; - state.isReady = true; - }, nodeEditorReset: (state) => { state.nodes = []; state.edges = []; @@ -893,6 +885,9 @@ const nodesSlice = createSlice({ }); } }); + builder.addCase(nodeTemplatesBuilt, (state) => { + state.isReady = true; + }); }, }); @@ -935,7 +930,6 @@ export const { nodeOpacityChanged, nodesChanged, nodesDeleted, - nodeTemplatesBuilt, nodeUseCacheChanged, notesNodeValueChanged, selectedAll, diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index 850abe26cf..9ea5343118 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -16,7 +16,6 @@ import type { export type NodesState = { nodes: AnyNode[]; edges: InvocationNodeEdge[]; - nodeTemplates: Record; connectionStartParams: OnConnectStartParams | null; connectionStartFieldType: FieldType | null; connectionMade: boolean; @@ -42,3 +41,7 @@ export type NodesState = { export type WorkflowsState = Omit & { isTouched: boolean; }; + +export type NodeTemplatesState = { + templates: Record; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index 1f9b742d88..004c59d9e5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -33,7 +33,7 @@ const zWorkflowMetaVersion = z.object({ * - Workflow schema version bumped to 2.0.0 */ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { - const invocationTemplates = $store.get()?.getState().nodes.nodeTemplates; + const invocationTemplates = $store.get()?.getState().nodeTemplates.templates; if (!invocationTemplates) { throw new Error(t('app.storeNotInitialized'));