From f6a44681a86b534a4db765956416fc88b5c64e32 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 16 May 2024 15:17:23 +1000 Subject: [PATCH] feat(ui): move invocation templates out of redux (wip) --- .../middleware/devtools/actionSanitizer.ts | 8 - .../listeners/getOpenAPISchema.ts | 6 +- .../listeners/updateAllNodesRequested.ts | 5 +- .../listeners/workflowLoadRequested.ts | 5 +- .../src/common/hooks/useIsReadyToEnqueue.ts | 352 +++++++++--------- .../flow/AddNodePopover/AddNodePopover.tsx | 23 +- .../Invocation/InvocationNodeWrapper.tsx | 14 +- .../inspector/InspectorDetailsTab.tsx | 43 ++- .../inspector/InspectorOutputsTab.tsx | 45 +-- .../inspector/InspectorTemplateTab.tsx | 31 +- .../src/features/nodes/hooks/useBuildNode.ts | 10 +- .../nodes/hooks/useGetNodesNeedUpdate.ts | 32 +- .../nodes/hooks/useIsValidConnection.ts | 7 +- .../src/features/nodes/store/nodesSlice.ts | 5 +- .../web/src/features/nodes/store/selectors.ts | 10 +- .../web/src/features/nodes/store/types.ts | 2 - .../nodes/util/workflow/graphToWorkflow.ts | 13 +- .../nodes/util/workflow/migrations.ts | 10 +- 18 files changed, 303 insertions(+), 318 deletions(-) diff --git a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts index 508109caf5..f0ea175aec 100644 --- a/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts +++ b/invokeai/frontend/web/src/app/store/middleware/devtools/actionSanitizer.ts @@ -1,7 +1,6 @@ import type { UnknownAction } from '@reduxjs/toolkit'; import { deepClone } from 'common/util/deepClone'; import { isAnyGraphBuilt } from 'features/nodes/store/actions'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; import { appInfoApi } from 'services/api/endpoints/appInfo'; import type { Graph } from 'services/api/types'; import { socketGeneratorProgress } from 'services/events/actions'; @@ -25,13 +24,6 @@ export const actionSanitizer = (action: A): A => { }; } - if (nodeTemplatesBuilt.match(action)) { - return { - ...action, - payload: '', - }; - } - if (socketGeneratorProgress.match(action)) { const sanitized = deepClone(action); if (sanitized.payload.data.progress_image) { diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts index acb2bdb698..923b2c0197 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema.ts @@ -1,7 +1,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { parseify } from 'common/util/serialize'; -import { nodeTemplatesBuilt } from 'features/nodes/store/nodesSlice'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { parseSchema } from 'features/nodes/util/schema/parseSchema'; import { size } from 'lodash-es'; import { appInfoApi } from 'services/api/endpoints/appInfo'; @@ -9,7 +9,7 @@ import { appInfoApi } from 'services/api/endpoints/appInfo'; export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening) => { startAppListening({ matcher: appInfoApi.endpoints.getOpenAPISchema.matchFulfilled, - effect: (action, { dispatch, getState }) => { + effect: (action, { getState }) => { const log = logger('system'); const schemaJSON = action.payload; @@ -20,7 +20,7 @@ export const addGetOpenAPISchemaListener = (startAppListening: AppStartListening log.debug({ nodeTemplates: parseify(nodeTemplates) }, `Built ${size(nodeTemplates)} node templates`); - dispatch(nodeTemplatesBuilt(nodeTemplates)); + $templates.set(nodeTemplates); }, }); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts index ebd4d00901..9c2ab4278d 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/updateAllNodesRequested.ts @@ -1,7 +1,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { updateAllNodesRequested } from 'features/nodes/store/actions'; -import { nodeReplaced } from 'features/nodes/store/nodesSlice'; +import { $templates, nodeReplaced } from 'features/nodes/store/nodesSlice'; import { NodeUpdateError } from 'features/nodes/types/error'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate, updateNode } from 'features/nodes/util/node/nodeUpdate'; @@ -14,7 +14,8 @@ export const addUpdateAllNodesRequestedListener = (startAppListening: AppStartLi actionCreator: updateAllNodesRequested, effect: (action, { dispatch, getState }) => { const log = logger('nodes'); - const { nodes, templates } = getState().nodes.present; + const { nodes } = getState().nodes.present; + const templates = $templates.get(); let unableToUpdateCount = 0; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts index 5a2c270b2a..4052c75bf3 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/workflowLoadRequested.ts @@ -2,6 +2,7 @@ import { logger } from 'app/logging/logger'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { parseify } from 'common/util/serialize'; import { workflowLoaded, workflowLoadRequested } from 'features/nodes/store/actions'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { $flow } from 'features/nodes/store/reactFlowInstance'; import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error'; import { validateWorkflow } from 'features/nodes/util/workflow/validateWorkflow'; @@ -14,10 +15,10 @@ import { fromZodError } from 'zod-validation-error'; export const addWorkflowLoadRequestedListener = (startAppListening: AppStartListening) => { startAppListening({ actionCreator: workflowLoadRequested, - effect: (action, { dispatch, getState }) => { + effect: (action, { dispatch }) => { const log = logger('nodes'); const { workflow, asCopy } = action.payload; - const nodeTemplates = getState().nodes.present.templates; + const nodeTemplates = $templates.get(); try { const { workflow: validatedWorkflow, warnings } = validateWorkflow(workflow, nodeTemplates); diff --git a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts index 972cb063cf..868509a59b 100644 --- a/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts +++ b/invokeai/frontend/web/src/common/hooks/useIsReadyToEnqueue.ts @@ -1,3 +1,4 @@ +import { useStore } from '@nanostores/react'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { @@ -9,14 +10,16 @@ import { selectControlLayersSlice } from 'features/controlLayers/store/controlLa import type { Layer } from 'features/controlLayers/store/types'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice'; +import type { InvocationTemplate } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectSystemSlice } from 'features/system/store/systemSlice'; import { activeTabNameSelector } from 'features/ui/store/uiSelectors'; import i18n from 'i18next'; import { forEach, upperFirst } from 'lodash-es'; +import { useMemo } from 'react'; import { getConnectedEdges } from 'reactflow'; const LAYER_TYPE_TO_TKEY: Record = { @@ -26,200 +29,205 @@ const LAYER_TYPE_TO_TKEY: Record = { regional_guidance_layer: 'controlLayers.regionalGuidance', }; -const selector = createMemoizedSelector( - [ - selectControlAdaptersSlice, - selectGenerationSlice, - selectSystemSlice, - selectNodesSlice, - selectWorkflowSettingsSlice, - selectDynamicPromptsSlice, - selectControlLayersSlice, - activeTabNameSelector, - ], - (controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => { - const { model } = generation; - const { size } = controlLayers.present; - const { positivePrompt } = controlLayers.present; +const createSelector = (templates: Record) => + createMemoizedSelector( + [ + selectControlAdaptersSlice, + selectGenerationSlice, + selectSystemSlice, + selectNodesSlice, + selectWorkflowSettingsSlice, + selectDynamicPromptsSlice, + selectControlLayersSlice, + activeTabNameSelector, + ], + (controlAdapters, generation, system, nodes, workflowSettings, dynamicPrompts, controlLayers, activeTabName) => { + const { model } = generation; + const { size } = controlLayers.present; + const { positivePrompt } = controlLayers.present; - const { isConnected } = system; + const { isConnected } = system; - const reasons: { prefix?: string; content: string }[] = []; + const reasons: { prefix?: string; content: string }[] = []; - // Cannot generate if not connected - if (!isConnected) { - reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') }); - } + // Cannot generate if not connected + if (!isConnected) { + reasons.push({ content: i18n.t('parameters.invoke.systemDisconnected') }); + } - if (activeTabName === 'workflows') { - if (workflowSettings.shouldValidateGraph) { - if (!nodes.nodes.length) { - reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') }); + if (activeTabName === 'workflows') { + if (workflowSettings.shouldValidateGraph) { + if (!nodes.nodes.length) { + reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') }); + } + + nodes.nodes.forEach((node) => { + if (!isInvocationNode(node)) { + return; + } + + const nodeTemplate = templates[node.data.type]; + + if (!nodeTemplate) { + // Node type not found + reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') }); + return; + } + + const connectedEdges = getConnectedEdges([node], nodes.edges); + + forEach(node.data.inputs, (field) => { + const fieldTemplate = nodeTemplate.inputs[field.name]; + const hasConnection = connectedEdges.some( + (edge) => edge.target === node.id && edge.targetHandle === field.name + ); + + if (!fieldTemplate) { + reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') }); + return; + } + + if (fieldTemplate.required && field.value === undefined && !hasConnection) { + reasons.push({ + content: i18n.t('parameters.invoke.missingInputForField', { + nodeLabel: node.data.label || nodeTemplate.title, + fieldLabel: field.label || fieldTemplate.title, + }), + }); + return; + } + }); + }); + } + } else { + if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) { + reasons.push({ content: i18n.t('parameters.invoke.noPrompts') }); } - nodes.nodes.forEach((node) => { - if (!isInvocationNode(node)) { - return; - } + if (!model) { + reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') }); + } - const nodeTemplate = nodes.templates[node.data.type]; - - if (!nodeTemplate) { - // Node type not found - reasons.push({ content: i18n.t('parameters.invoke.missingNodeTemplate') }); - return; - } - - const connectedEdges = getConnectedEdges([node], nodes.edges); - - forEach(node.data.inputs, (field) => { - const fieldTemplate = nodeTemplate.inputs[field.name]; - const hasConnection = connectedEdges.some( - (edge) => edge.target === node.id && edge.targetHandle === field.name - ); - - if (!fieldTemplate) { - reasons.push({ content: i18n.t('parameters.invoke.missingFieldTemplate') }); - return; - } - - if (fieldTemplate.required && field.value === undefined && !hasConnection) { - reasons.push({ - content: i18n.t('parameters.invoke.missingInputForField', { - nodeLabel: node.data.label || nodeTemplate.title, - fieldLabel: field.label || fieldTemplate.title, - }), - }); - return; - } - }); - }); - } - } else { - if (dynamicPrompts.prompts.length === 0 && getShouldProcessPrompt(positivePrompt)) { - reasons.push({ content: i18n.t('parameters.invoke.noPrompts') }); - } - - if (!model) { - reasons.push({ content: i18n.t('parameters.invoke.noModelSelected') }); - } - - if (activeTabName === 'generation') { - // Handling for generation tab - controlLayers.present.layers - .filter((l) => l.isEnabled) - .forEach((l, i) => { - const layerLiteral = i18n.t('controlLayers.layers_one'); - const layerNumber = i + 1; - const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]); - const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; - if (l.type === 'control_adapter_layer') { - // Must have model - if (!l.controlAdapter.model) { - problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected')); - } - // Model base must match - if (l.controlAdapter.model?.base !== model?.base) { - problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel')); - } - // Must have a control image OR, if it has a processor, it must have a processed image - if (!l.controlAdapter.image) { - problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected')); - } else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) { - problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed')); - } - // T2I Adapters require images have dimensions that are multiples of 64 - if (l.controlAdapter.type === 't2i_adapter' && (size.width % 64 !== 0 || size.height % 64 !== 0)) { - problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions')); - } - } - - if (l.type === 'ip_adapter_layer') { - // Must have model - if (!l.ipAdapter.model) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected')); - } - // Model base must match - if (l.ipAdapter.model?.base !== model?.base) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); - } - // Must have an image - if (!l.ipAdapter.image) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); - } - } - - if (l.type === 'initial_image_layer') { - // Must have an image - if (!l.image) { - problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected')); - } - } - - if (l.type === 'regional_guidance_layer') { - // Must have a region - if (l.maskObjects.length === 0) { - problems.push(i18n.t('parameters.invoke.layer.rgNoRegion')); - } - // Must have at least 1 prompt or IP Adapter - if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) { - problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters')); - } - l.ipAdapters.forEach((ipAdapter) => { + if (activeTabName === 'generation') { + // Handling for generation tab + controlLayers.present.layers + .filter((l) => l.isEnabled) + .forEach((l, i) => { + const layerLiteral = i18n.t('controlLayers.layers_one'); + const layerNumber = i + 1; + const layerType = i18n.t(LAYER_TYPE_TO_TKEY[l.type]); + const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; + const problems: string[] = []; + if (l.type === 'control_adapter_layer') { // Must have model - if (!ipAdapter.model) { + if (!l.controlAdapter.model) { + problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected')); + } + // Model base must match + if (l.controlAdapter.model?.base !== model?.base) { + problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel')); + } + // Must have a control image OR, if it has a processor, it must have a processed image + if (!l.controlAdapter.image) { + problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoImageSelected')); + } else if (l.controlAdapter.processorConfig && !l.controlAdapter.processedImage) { + problems.push(i18n.t('parameters.invoke.layer.controlAdapterImageNotProcessed')); + } + // T2I Adapters require images have dimensions that are multiples of 64 + if (l.controlAdapter.type === 't2i_adapter' && (size.width % 64 !== 0 || size.height % 64 !== 0)) { + problems.push(i18n.t('parameters.invoke.layer.t2iAdapterIncompatibleDimensions')); + } + } + + if (l.type === 'ip_adapter_layer') { + // Must have model + if (!l.ipAdapter.model) { problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected')); } // Model base must match - if (ipAdapter.model?.base !== model?.base) { + if (l.ipAdapter.model?.base !== model?.base) { problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); } // Must have an image - if (!ipAdapter.image) { + if (!l.ipAdapter.image) { problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); } - }); - } + } - if (problems.length) { - const content = upperFirst(problems.join(', ')); - reasons.push({ prefix, content }); - } - }); - } else { - // Handling for all other tabs - selectControlAdapterAll(controlAdapters) - .filter((ca) => ca.isEnabled) - .forEach((ca, i) => { - if (!ca.isEnabled) { - return; - } + if (l.type === 'initial_image_layer') { + // Must have an image + if (!l.image) { + problems.push(i18n.t('parameters.invoke.layer.initialImageNoImageSelected')); + } + } - if (!ca.model) { - reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) }); - } else if (ca.model.base !== model?.base) { - // This should never happen, just a sanity check - reasons.push({ - content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }), - }); - } + if (l.type === 'regional_guidance_layer') { + // Must have a region + if (l.maskObjects.length === 0) { + problems.push(i18n.t('parameters.invoke.layer.rgNoRegion')); + } + // Must have at least 1 prompt or IP Adapter + if (l.positivePrompt === null && l.negativePrompt === null && l.ipAdapters.length === 0) { + problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters')); + } + l.ipAdapters.forEach((ipAdapter) => { + // Must have model + if (!ipAdapter.model) { + problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected')); + } + // Model base must match + if (ipAdapter.model?.base !== model?.base) { + problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); + } + // Must have an image + if (!ipAdapter.image) { + problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); + } + }); + } - if ( - !ca.controlImage || - (isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none') - ) { - reasons.push({ content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }) }); - } - }); + if (problems.length) { + const content = upperFirst(problems.join(', ')); + reasons.push({ prefix, content }); + } + }); + } else { + // Handling for all other tabs + selectControlAdapterAll(controlAdapters) + .filter((ca) => ca.isEnabled) + .forEach((ca, i) => { + if (!ca.isEnabled) { + return; + } + + if (!ca.model) { + reasons.push({ content: i18n.t('parameters.invoke.noModelForControlAdapter', { number: i + 1 }) }); + } else if (ca.model.base !== model?.base) { + // This should never happen, just a sanity check + reasons.push({ + content: i18n.t('parameters.invoke.incompatibleBaseModelForControlAdapter', { number: i + 1 }), + }); + } + + if ( + !ca.controlImage || + (isControlNetOrT2IAdapter(ca) && !ca.processedControlImage && ca.processorType !== 'none') + ) { + reasons.push({ + content: i18n.t('parameters.invoke.noControlImageForControlAdapter', { number: i + 1 }), + }); + } + }); + } } - } - return { isReady: !reasons.length, reasons }; - } -); + return { isReady: !reasons.length, reasons }; + } + ); export const useIsReadyToEnqueue = () => { + const templates = useStore($templates); + const selector = useMemo(() => createSelector(templates), [templates]); const value = useAppSelector(selector); return value; }; diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx index 6cfc95e311..6d33905f4c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/AddNodePopover/AddNodePopover.tsx @@ -2,21 +2,16 @@ import 'reactflow/dist/style.css'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library'; import { Combobox, Flex, Popover, PopoverAnchor, PopoverBody, PopoverContent } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { useAppToaster } from 'app/components/Toaster'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import type { SelectInstance } from 'chakra-react-select'; import { useBuildNode } from 'features/nodes/hooks/useBuildNode'; -import { - addNodePopoverClosed, - addNodePopoverOpened, - nodeAdded, - selectNodesSlice, -} from 'features/nodes/store/nodesSlice'; +import { $templates, addNodePopoverClosed, addNodePopoverOpened, nodeAdded } from 'features/nodes/store/nodesSlice'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import { filter, map, memoize, some } from 'lodash-es'; import type { KeyboardEventHandler } from 'react'; -import { memo, useCallback, useRef } from 'react'; +import { memo, useCallback, useMemo, useRef } from 'react'; import { flushSync } from 'react-dom'; import { useHotkeys } from 'react-hotkeys-hook'; import type { HotkeyCallback } from 'react-hotkeys-hook/dist/types'; @@ -54,14 +49,15 @@ const AddNodePopover = () => { const { t } = useTranslation(); const selectRef = useRef | null>(null); const inputRef = useRef(null); + const templates = useStore($templates); const fieldFilter = useAppSelector((s) => s.nodes.present.connectionStartFieldType); const handleFilter = useAppSelector((s) => s.nodes.present.connectionStartParams?.handleType); - const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { + const options = useMemo(() => { // If we have a connection in progress, we need to filter the node choices const filteredNodeTemplates = fieldFilter - ? filter(nodes.templates, (template) => { + ? filter(templates, (template) => { const handles = handleFilter === 'source' ? template.inputs : template.outputs; return some(handles, (handle) => { @@ -71,7 +67,7 @@ const AddNodePopover = () => { return validateSourceAndTargetTypes(sourceType, targetType); }); }) - : map(nodes.templates); + : map(templates); const options: ComboboxOption[] = map(filteredNodeTemplates, (template) => { return { @@ -101,10 +97,9 @@ const AddNodePopover = () => { options.sort((a, b) => a.label.localeCompare(b.label)); - return { options }; - }); + return options; + }, [fieldFilter, handleFilter, t, templates]); - const { options } = useAppSelector(selector); const isOpen = useAppSelector((s) => s.nodes.present.isAddNodePopoverOpen); const addNode = useCallback( diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx index 0fe81c0882..cebf9cf3c5 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/InvocationNodeWrapper.tsx @@ -1,7 +1,6 @@ -import { createSelector } from '@reduxjs/toolkit'; -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; import InvocationNode from 'features/nodes/components/flow/nodes/Invocation/InvocationNode'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { $templates } from 'features/nodes/store/nodesSlice'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; import { memo, useMemo } from 'react'; import type { NodeProps } from 'reactflow'; @@ -11,13 +10,8 @@ import InvocationNodeUnknownFallback from './InvocationNodeUnknownFallback'; const InvocationNodeWrapper = (props: NodeProps) => { const { data, selected } = props; const { id: nodeId, type, isOpen, label } = data; - - const hasTemplateSelector = useMemo( - () => createSelector(selectNodesSlice, (nodes) => Boolean(nodes.templates[type])), - [type] - ); - - const hasTemplate = useAppSelector(hasTemplateSelector); + const templates = useStore($templates); + const hasTemplate = useMemo(() => Boolean(templates[type]), [templates, type]); if (!hasTemplate) { return ( diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx index d72d2f5aa8..354a0ed179 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorDetailsTab.tsx @@ -1,36 +1,39 @@ import { Box, Flex, FormControl, FormLabel, HStack, Text } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import NotesTextarea from 'features/nodes/components/flow/nodes/Invocation/NotesTextarea'; import { useNodeNeedsUpdate } from 'features/nodes/hooks/useNodeNeedsUpdate'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; -import { memo } from 'react'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import EditableNodeTitle from './details/EditableNodeTitle'; -const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { - const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; - - const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - - const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; - - if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { - return; - } - - return { - nodeId: lastSelectedNode.data.id, - nodeVersion: lastSelectedNode.data.version, - templateTitle: lastSelectedNodeTemplate.title, - }; -}); - const InspectorDetailsTab = () => { + const templates = useStore($templates); + const selector = useMemo( + () => + createMemoizedSelector(selectNodesSlice, (nodes) => { + const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; + const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); + const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined; + + if (!isInvocationNode(lastSelectedNode) || !lastSelectedNodeTemplate) { + return; + } + + return { + nodeId: lastSelectedNode.data.id, + nodeVersion: lastSelectedNode.data.version, + templateTitle: lastSelectedNodeTemplate.title, + }; + }), + [templates] + ); const data = useAppSelector(selector); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx index 978eeddd24..381a510b8b 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorOutputsTab.tsx @@ -1,38 +1,41 @@ import { Box, Flex } from '@invoke-ai/ui-library'; +import { useStore } from '@nanostores/react'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; -import { memo } from 'react'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import type { ImageOutput } from 'services/api/types'; import type { AnyResult } from 'services/events/types'; import ImageOutputPreview from './outputs/ImageOutputPreview'; -const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { - const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; - - const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - - const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; - - const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; - - if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) { - return; - } - - return { - outputs: nes.outputs, - outputType: lastSelectedNodeTemplate.outputType, - }; -}); - const InspectorOutputsTab = () => { + const templates = useStore($templates); + const selector = useMemo( + () => + createMemoizedSelector(selectNodesSlice, (nodes) => { + const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; + const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); + const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined; + + const nes = nodes.nodeExecutionStates[lastSelectedNodeId ?? '__UNKNOWN_NODE__']; + + if (!isInvocationNode(lastSelectedNode) || !nes || !lastSelectedNodeTemplate) { + return; + } + + return { + outputs: nes.outputs, + outputType: lastSelectedNodeTemplate.outputType, + }; + }), + [templates] + ); const data = useAppSelector(selector); const { t } = useTranslation(); diff --git a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx index ea6e8ed704..fbe86ba32c 100644 --- a/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/sidePanel/inspector/InspectorTemplateTab.tsx @@ -1,25 +1,26 @@ +import { useStore } from '@nanostores/react'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { useAppSelector } from 'app/store/storeHooks'; import { IAINoContentFallback } from 'common/components/IAIImageFallback'; import DataViewer from 'features/gallery/components/ImageMetadataViewer/DataViewer'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; -import { memo } from 'react'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { memo, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; -const selector = createMemoizedSelector(selectNodesSlice, (nodes) => { - const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; - - const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); - - const lastSelectedNodeTemplate = lastSelectedNode ? nodes.templates[lastSelectedNode.data.type] : undefined; - - return { - template: lastSelectedNodeTemplate, - }; -}); - const NodeTemplateInspector = () => { - const { template } = useAppSelector(selector); + const templates = useStore($templates); + const selector = useMemo( + () => + createMemoizedSelector(selectNodesSlice, (nodes) => { + const lastSelectedNodeId = nodes.selectedNodes[nodes.selectedNodes.length - 1]; + const lastSelectedNode = nodes.nodes.find((node) => node.id === lastSelectedNodeId); + const lastSelectedNodeTemplate = lastSelectedNode ? templates[lastSelectedNode.data.type] : undefined; + + return lastSelectedNodeTemplate; + }), + [templates] + ); + const template = useAppSelector(selector); const { t } = useTranslation(); if (!template) { diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts index b166b71788..4e96c219f8 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useBuildNode.ts @@ -1,4 +1,5 @@ -import { useAppSelector } from 'app/store/storeHooks'; +import { useStore } from '@nanostores/react'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { NODE_WIDTH } from 'features/nodes/types/constants'; import type { AnyNode, InvocationTemplate } from 'features/nodes/types/invocation'; import { buildCurrentImageNode } from 'features/nodes/util/node/buildCurrentImageNode'; @@ -8,8 +9,7 @@ import { useCallback } from 'react'; import { useReactFlow } from 'reactflow'; export const useBuildNode = () => { - const nodeTemplates = useAppSelector((s) => s.nodes.present.templates); - + const templates = useStore($templates); const flow = useReactFlow(); return useCallback( @@ -41,10 +41,10 @@ export const useBuildNode = () => { // TODO: Keep track of invocation types so we do not need to cast this // We know it is safe because the caller of this function gets the `type` arg from the list of invocation templates. - const template = nodeTemplates[type] as InvocationTemplate; + const template = templates[type] as InvocationTemplate; return buildInvocationNode(position, template); }, - [nodeTemplates, flow] + [templates, flow] ); }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts index 71344197d5..4adbb19c5c 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useGetNodesNeedUpdate.ts @@ -1,20 +1,26 @@ +import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { useAppSelector } from 'app/store/storeHooks'; -import { selectNodesSlice } from 'features/nodes/store/nodesSlice'; +import { $templates, selectNodesSlice } from 'features/nodes/store/nodesSlice'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { getNeedsUpdate } from 'features/nodes/util/node/nodeUpdate'; - -const selector = createSelector(selectNodesSlice, (nodes) => - nodes.nodes.filter(isInvocationNode).some((node) => { - const template = nodes.templates[node.data.type]; - if (!template) { - return false; - } - return getNeedsUpdate(node, template); - }) -); +import { useMemo } from 'react'; export const useGetNodesNeedUpdate = () => { - const getNeedsUpdate = useAppSelector(selector); - return getNeedsUpdate; + const templates = useStore($templates); + const selector = useMemo( + () => + createSelector(selectNodesSlice, (nodes) => + nodes.nodes.filter(isInvocationNode).some((node) => { + const template = templates[node.data.type]; + if (!template) { + return false; + } + return getNeedsUpdate(node, template); + }) + ), + [templates] + ); + const needsUpdate = useAppSelector(selector); + return needsUpdate; }; diff --git a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts index 7ab28f58c2..041faab149 100644 --- a/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts +++ b/invokeai/frontend/web/src/features/nodes/hooks/useIsValidConnection.ts @@ -1,5 +1,7 @@ // TODO: enable this at some point +import { useStore } from '@nanostores/react'; import { useAppSelector, useAppStore } from 'app/store/storeHooks'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { getIsGraphAcyclic } from 'features/nodes/store/util/getIsGraphAcyclic'; import { validateSourceAndTargetTypes } from 'features/nodes/store/util/validateSourceAndTargetTypes'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; @@ -13,6 +15,7 @@ import type { Connection, Node } from 'reactflow'; export const useIsValidConnection = () => { const store = useAppStore(); + const templates = useStore($templates); const shouldValidateGraph = useAppSelector((s) => s.workflowSettings.shouldValidateGraph); const isValidConnection = useCallback( ({ source, sourceHandle, target, targetHandle }: Connection): boolean => { @@ -27,7 +30,7 @@ export const useIsValidConnection = () => { } const state = store.getState(); - const { nodes, edges, templates } = state.nodes.present; + const { nodes, edges } = state.nodes.present; // Find the source and target nodes const sourceNode = nodes.find((node) => node.id === source) as Node; @@ -76,7 +79,7 @@ export const useIsValidConnection = () => { // Graphs much be acyclic (no loops!) return getIsGraphAcyclic(source, target, nodes, edges); }, - [shouldValidateGraph, store] + [shouldValidateGraph, templates, store] ); return isValidConnection; diff --git a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts index 21092bb7df..3d18a01493 100644 --- a/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts +++ b/invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts @@ -602,9 +602,6 @@ export const nodesSlice = createSlice({ state.connectionStartParams = null; state.connectionStartFieldType = null; }, - nodeTemplatesBuilt: (state, action: PayloadAction>) => { - state.templates = action.payload; - }, undo: (state) => state, redo: (state) => state, }, @@ -728,7 +725,6 @@ export const { selectionPasted, viewportChanged, edgeAdded, - nodeTemplatesBuilt, undo, redo, } = nodesSlice.actions; @@ -770,6 +766,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf( ); export const $cursorPos = atom(null); +export const $templates = atom>({}); export const $copiedNodes = atom([]); export const $copiedEdges = atom([]); diff --git a/invokeai/frontend/web/src/features/nodes/store/selectors.ts b/invokeai/frontend/web/src/features/nodes/store/selectors.ts index 90675d6270..d473005395 100644 --- a/invokeai/frontend/web/src/features/nodes/store/selectors.ts +++ b/invokeai/frontend/web/src/features/nodes/store/selectors.ts @@ -1,6 +1,6 @@ import type { NodesState } from 'features/nodes/store/types'; import type { FieldInputInstance, FieldInputTemplate, FieldOutputTemplate } from 'features/nodes/types/field'; -import type { InvocationNode, InvocationNodeData, InvocationTemplate } from 'features/nodes/types/invocation'; +import type { InvocationNode, InvocationNodeData } from 'features/nodes/types/invocation'; import { isInvocationNode } from 'features/nodes/types/invocation'; export const selectInvocationNode = (nodesSlice: NodesState, nodeId: string): InvocationNode | null => { @@ -15,14 +15,6 @@ export const selectNodeData = (nodesSlice: NodesState, nodeId: string): Invocati return selectInvocationNode(nodesSlice, nodeId)?.data ?? null; }; -export const selectNodeTemplate = (nodesSlice: NodesState, nodeId: string): InvocationTemplate | null => { - const node = selectInvocationNode(nodesSlice, nodeId); - if (!node) { - return null; - } - return nodesSlice.templates[node.data.type] ?? null; -}; - export const selectFieldInputInstance = ( nodesSlice: NodesState, nodeId: string, diff --git a/invokeai/frontend/web/src/features/nodes/store/types.ts b/invokeai/frontend/web/src/features/nodes/store/types.ts index f9c859fcc5..28b87128d0 100644 --- a/invokeai/frontend/web/src/features/nodes/store/types.ts +++ b/invokeai/frontend/web/src/features/nodes/store/types.ts @@ -2,7 +2,6 @@ import type { FieldIdentifier, FieldType, StatefulFieldValue } from 'features/no import type { AnyNode, InvocationNodeEdge, - InvocationTemplate, NodeExecutionState, } from 'features/nodes/types/invocation'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; @@ -12,7 +11,6 @@ export type NodesState = { _version: 1; nodes: AnyNode[]; edges: InvocationNodeEdge[]; - templates: Record; connectionStartParams: OnConnectStartParams | null; connectionStartFieldType: FieldType | null; connectionMade: boolean; diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts index 361e3134ae..af66d3cc6b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/graphToWorkflow.ts @@ -1,11 +1,10 @@ import * as dagre from '@dagrejs/dagre'; import { logger } from 'app/logging/logger'; -import { getStore } from 'app/store/nanostores/store'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { NODE_WIDTH } from 'features/nodes/types/constants'; import type { FieldInputInstance } from 'features/nodes/types/field'; import type { WorkflowV3 } from 'features/nodes/types/workflow'; import { buildFieldInputInstance } from 'features/nodes/util/schema/buildFieldInputInstance'; -import { t } from 'i18next'; import { forEach } from 'lodash-es'; import type { NonNullableGraph } from 'services/api/types'; import { v4 as uuidv4 } from 'uuid'; @@ -18,11 +17,7 @@ import { v4 as uuidv4 } from 'uuid'; * @returns The workflow. */ export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): WorkflowV3 => { - const invocationTemplates = getStore().getState().nodes.present.templates; - - if (!invocationTemplates) { - throw new Error(t('app.storeNotInitialized')); - } + const templates = $templates.get(); // Initialize the workflow const workflow: WorkflowV3 = { @@ -44,11 +39,11 @@ export const graphToWorkflow = (graph: NonNullableGraph, autoLayout = true): Wor // Convert nodes forEach(graph.nodes, (node) => { - const template = invocationTemplates[node.type]; + const template = templates[node.type]; // Skip missing node templates - this is a best-effort if (!template) { - logger('nodes').warn(`Node type ${node.type} not found in invocationTemplates`); + logger('nodes').warn(`Node type ${node.type} not found in templates`); return; } diff --git a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts index 3f666e8771..32369b88c9 100644 --- a/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts +++ b/invokeai/frontend/web/src/features/nodes/util/workflow/migrations.ts @@ -1,5 +1,5 @@ -import { $store } from 'app/store/nanostores/store'; import { deepClone } from 'common/util/deepClone'; +import { $templates } from 'features/nodes/store/nodesSlice'; import { WorkflowMigrationError, WorkflowVersionError } from 'features/nodes/types/error'; import type { FieldType } from 'features/nodes/types/field'; import type { InvocationNodeData } from 'features/nodes/types/invocation'; @@ -33,11 +33,7 @@ const zWorkflowMetaVersion = z.object({ * - Workflow schema version bumped to 2.0.0 */ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { - const invocationTemplates = $store.get()?.getState().nodes.present.templates; - - if (!invocationTemplates) { - throw new Error(t('app.storeNotInitialized')); - } + const templates = $templates.get(); workflowToMigrate.nodes.forEach((node) => { if (node.type === 'invocation') { @@ -57,7 +53,7 @@ const migrateV1toV2 = (workflowToMigrate: WorkflowV1): WorkflowV2 => { (output.type as unknown as FieldType) = newFieldType; }); // Add node pack - const invocationTemplate = invocationTemplates[node.data.type]; + const invocationTemplate = templates[node.data.type]; const nodePack = invocationTemplate ? invocationTemplate.nodePack : t('common.unknown'); (node.data as unknown as InvocationNodeData).nodePack = nodePack;