From 2e13bbbe1babc66119ba419bd0112e54201cf6dc Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Wed, 26 Feb 2025 09:25:56 +1000 Subject: [PATCH] refactor(ui): make all readiness checking async This fixes the broken readiness checks introduced in the previous commit. To support async batch generators, all of the validation of the generators needs to be async. This is problematic because a lot of the validation logic was in redux selectors, which are necessarily synchronous. To resolve this, the readiness checks and related logic are restructured to be run async in response to redux state changes via `useEffect` (another option is to directly subscribe to redux store). These async functions then set some react state. The checks are debounced to prevent thrashing the UI. See #7580 for more context about this issue. Other changes: - Fix a minor issue where empty collections were also checked against their min and max sizes, and errors were shown for all the checks. If a collection is empty, we don't need to do the min/max checks. If a collection is empty, we skip the other min/max checks and do not report those errors to the user. - When a field is connected, do not attempt to check its value. This fixes an issue where collection fields with a connection could erroneously appear to be invalid. - Improved error messages for batch nodes. --- invokeai/frontend/web/public/locales/en.json | 3 +- .../listeners/enqueueRequestedNodes.ts | 14 +- .../inputs/StringGeneratorFieldComponent.tsx | 6 +- .../nodes/store/util/fieldValidators.ts | 44 +++--- .../web/src/features/nodes/types/field.ts | 19 +-- .../nodes/util/node/resolveBatchValue.ts | 136 ++++++++++++++---- .../InvokeButtonTooltip.tsx | 40 ++++-- .../web/src/features/queue/store/readiness.ts | 120 ++++------------ .../src/services/api/endpoints/utilities.ts | 2 - 9 files changed, 210 insertions(+), 174 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 6370d34f26..459bf17d6b 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1080,7 +1080,7 @@ "emptyBatches": "empty batches", "batchNodeNotConnected": "Batch node not connected: {{label}}", "batchNodeEmptyCollection": "Some batch nodes have empty collections", - "invalidBatchConfigurationCannotCalculate": "Invalid batch configuration; cannot calculate", + "collectionEmpty": "empty collection", "collectionTooFewItems": "too few items, minimum {{minItems}}", "collectionTooManyItems": "too many items, maximum {{maxItems}}", "collectionStringTooLong": "too long, max {{maxLength}}", @@ -1090,6 +1090,7 @@ "collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)", "collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)", "collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}", + "batchNodeCollectionSizeMismatchNoGroupId": "Batch group collection size mismatch", "batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}", "noModelSelected": "No model selected", "noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts index fa5fca0757..6a14028ad9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts @@ -16,13 +16,13 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) = enqueueRequested.match(action) && action.payload.tabName === 'workflows', effect: async (action, { getState, dispatch }) => { const state = getState(); - const nodes = selectNodesSlice(state); + const nodesState = selectNodesSlice(state); const workflow = state.workflow; const templates = $templates.get(); const graph = buildNodesGraph(state, templates); const builtWorkflow = buildWorkflowWithValidation({ - nodes: nodes.nodes, - edges: nodes.edges, + nodes: nodesState.nodes, + edges: nodesState.edges, workflow, }); @@ -33,7 +33,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) = const data: Batch['data'] = []; - const invocationNodes = nodes.nodes.filter(isInvocationNode); + const invocationNodes = nodesState.nodes.filter(isInvocationNode); const batchNodes = invocationNodes.filter(isBatchNode); // Handle zipping batch nodes. First group the batch nodes by their batch_group_id @@ -44,9 +44,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) = const zippedBatchDataCollectionItems: NonNullable[number] = []; for (const node of batchNodes) { - const value = resolveBatchValue(node, invocationNodes, nodes.edges); + const value = await resolveBatchValue({ nodesState, node, dispatch }); const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value'; - const edgesFromBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === sourceHandle); + const edgesFromBatch = nodesState.edges.filter( + (e) => e.source === node.id && e.sourceHandle === sourceHandle + ); if (batchGroupId !== 'None') { // If this batch node has a batch_group_id, we will zip the data collection items for (const edge of edgesFromBatch) { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent.tsx index 4e57bf87ff..2c5848043e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent.tsx @@ -1,5 +1,4 @@ import { Flex, Select, Text } from '@invoke-ai/ui-library'; -import { useAppStore } from 'app/store/nanostores/store'; import { useAppDispatch } from 'app/store/storeHooks'; import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants'; import { StringGeneratorDynamicPromptsCombinatorialSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsCombinatorialSettings'; @@ -29,7 +28,6 @@ export const StringGeneratorFieldInputComponent = memo( const { nodeId, field } = props; const { t } = useTranslation(); const dispatch = useAppDispatch(); - const store = useAppStore(); const onChange = useCallback( (value: StringGeneratorFieldInputInstance['value']) => { @@ -62,14 +60,14 @@ export const StringGeneratorFieldInputComponent = memo( const resolveAndSetValuesAsString = useMemo( () => debounce(async (field: StringGeneratorFieldInputInstance) => { - const resolvedValues = await resolveStringGeneratorField(field, store); + const resolvedValues = await resolveStringGeneratorField(field, dispatch); if (resolvedValues.length === 0) { setResolvedValuesAsString(`<${t('nodes.generatorNoValues')}>`); } else { setResolvedValuesAsString(resolvedValues.join(', ')); } }, 300), - [store, t] + [dispatch, t] ); useEffect(() => { resolveAndSetValuesAsString(field); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts b/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts index 76d2962e4c..94fa97ecab 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts @@ -36,14 +36,14 @@ const validateImageFieldCollectionValue = ( // Image collections may have min or max items to validate if (minItems !== undefined && minItems > 0 && count === 0) { reasons.push(t('parameters.invoke.collectionEmpty')); - } + } else { + if (minItems !== undefined && count < minItems) { + reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); + } - if (minItems !== undefined && count < minItems) { - reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); - } - - if (maxItems !== undefined && count > maxItems) { - reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + if (maxItems !== undefined && count > maxItems) { + reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + } } return reasons; @@ -60,14 +60,14 @@ const validateStringFieldCollectionValue = ( // Image collections may have min or max items to validate if (minItems !== undefined && minItems > 0 && count === 0) { reasons.push(t('parameters.invoke.collectionEmpty')); - } + } else { + if (minItems !== undefined && count < minItems) { + reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); + } - if (minItems !== undefined && count < minItems) { - reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); - } - - if (maxItems !== undefined && count > maxItems) { - reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + if (maxItems !== undefined && count > maxItems) { + reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + } } for (const str of value) { @@ -93,14 +93,14 @@ const validateNumberFieldCollectionValue = ( // Image collections may have min or max items to validate if (minItems !== undefined && minItems > 0 && count === 0) { reasons.push(t('parameters.invoke.collectionEmpty')); - } + } else { + if (minItems !== undefined && count < minItems) { + reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); + } - if (minItems !== undefined && count < minItems) { - reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); - } - - if (maxItems !== undefined && count > maxItems) { - reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + if (maxItems !== undefined && count > maxItems) { + reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + } } for (const num of value) { @@ -174,6 +174,8 @@ export const getFieldErrors = ( prefix, issue: t('parameters.invoke.missingInputForField'), }); + } else if (isConnected) { + // Connected fields have no value to validate - they are OK } else if ( field.value && isImageFieldCollectionInputTemplate(fieldTemplate) && diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index d405362c0f..e293b81e89 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -1,5 +1,5 @@ import { EMPTY_ARRAY } from 'app/store/constants'; -import type { AppStore } from 'app/store/store'; +import type { AppDispatch } from 'app/store/store'; import { isNil, random, trim } from 'lodash-es'; import MersenneTwister from 'mtwist'; import { utilitiesApi } from 'services/api/endpoints/utilities'; @@ -1414,10 +1414,10 @@ const getStringGeneratorDynamicPromptsCombinatorialDefaults = () => zStringGeneratorDynamicPromptsCombinatorial.parse({}); const getStringGeneratorDynamicPromptsCombinatorialValues = async ( generator: StringGeneratorDynamicPromptsCombinatorial, - store: AppStore + dispatch: AppDispatch ): Promise => { const { input, maxPrompts } = generator; - const req = store.dispatch( + const req = dispatch( utilitiesApi.endpoints.dynamicPrompts.initiate( { prompt: input, @@ -1452,10 +1452,10 @@ export type StringGeneratorDynamicPromptsRandom = z.infer zStringGeneratorDynamicPromptsRandom.parse({}); const getStringGeneratorDynamicPromptsRandomValues = async ( generator: StringGeneratorDynamicPromptsRandom, - store: AppStore + dispatch: AppDispatch ): Promise => { const { input, seed, count } = generator; - const req = store.dispatch( + const req = dispatch( utilitiesApi.endpoints.dynamicPrompts.initiate( { prompt: input, @@ -1503,7 +1503,10 @@ export const isStringGeneratorFieldInputTemplate = buildTemplateTypeGuard { +export const resolveStringGeneratorField = async ( + { value }: StringGeneratorFieldInputInstance, + dispatch: AppDispatch +) => { if (value.values) { return value.values; } @@ -1511,10 +1514,10 @@ export const resolveStringGeneratorField = async ({ value }: StringGeneratorFiel return getStringGeneratorParseStringValues(value); } if (value.type === StringGeneratorDynamicPromptsRandomType) { - return await getStringGeneratorDynamicPromptsRandomValues(value, store); + return await getStringGeneratorDynamicPromptsRandomValues(value, dispatch); } if (value.type === StringGeneratorDynamicPromptsCombinatorialType) { - return await getStringGeneratorDynamicPromptsCombinatorialValues(value, store); + return await getStringGeneratorDynamicPromptsCombinatorialValues(value, dispatch); } assert(false, 'Invalid string generator type'); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/node/resolveBatchValue.ts b/invokeai/frontend/web/src/features/nodes/util/node/resolveBatchValue.ts index 52098839a2..a4cb58baa0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/resolveBatchValue.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/resolveBatchValue.ts @@ -1,4 +1,5 @@ -import type { AppStore } from 'app/store/store'; +import type { AppDispatch } from 'app/store/store'; +import type { NodesState } from 'features/nodes/store/types'; import { isFloatFieldCollectionInputInstance, isFloatGeneratorFieldInputInstance, @@ -11,47 +12,51 @@ import { resolveIntegerGeneratorField, resolveStringGeneratorField, } from 'features/nodes/types/field'; -import type { AnyEdge, InvocationNode } from 'features/nodes/types/invocation'; +import type { InvocationNode } from 'features/nodes/types/invocation'; +import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation'; import { assert } from 'tsafe'; -export const resolveBatchValue = async ( - batchNode: InvocationNode, - nodes: InvocationNode[], - edges: AnyEdge[], - store: AppStore -) => { - if (batchNode.data.type === 'image_batch') { - assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images)); - const ownValue = batchNode.data.inputs.images.value ?? []; +export const resolveBatchValue = async (arg: { + dispatch: AppDispatch; + nodesState: NodesState; + node: InvocationNode; +}) => { + const { node, dispatch, nodesState } = arg; + const { nodes, edges } = nodesState; + const invocationNodes = nodes.filter(isInvocationNode); + + if (node.data.type === 'image_batch') { + assert(isImageFieldCollectionInputInstance(node.data.inputs.images)); + const ownValue = node.data.inputs.images.value ?? []; // no generators for images yet return ownValue; - } else if (batchNode.data.type === 'string_batch') { - assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings)); - const ownValue = batchNode.data.inputs.strings.value; - const edgeToStrings = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'strings'); + } else if (node.data.type === 'string_batch') { + assert(isStringFieldCollectionInputInstance(node.data.inputs.strings)); + const ownValue = node.data.inputs.strings.value; + const edgeToStrings = edges.find((edge) => edge.target === node.id && edge.targetHandle === 'strings'); if (!edgeToStrings) { return ownValue ?? []; } - const generatorNode = nodes.find((node) => node.id === edgeToStrings.source); + const generatorNode = invocationNodes.find((node) => node.id === edgeToStrings.source); assert(generatorNode, 'Missing edge from string generator to string batch'); const generatorField = generatorNode.data.inputs['generator']; assert(isStringGeneratorFieldInputInstance(generatorField), 'Invalid string generator'); - const generatorValue = await resolveStringGeneratorField(generatorField, store); + const generatorValue = await resolveStringGeneratorField(generatorField, dispatch); return generatorValue; - } else if (batchNode.data.type === 'float_batch') { - assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats)); - const ownValue = batchNode.data.inputs.floats.value; - const edgeToFloats = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'floats'); + } else if (node.data.type === 'float_batch') { + assert(isFloatFieldCollectionInputInstance(node.data.inputs.floats)); + const ownValue = node.data.inputs.floats.value; + const edgeToFloats = edges.find((edge) => edge.target === node.id && edge.targetHandle === 'floats'); if (!edgeToFloats) { return ownValue ?? []; } - const generatorNode = nodes.find((node) => node.id === edgeToFloats.source); + const generatorNode = invocationNodes.find((node) => node.id === edgeToFloats.source); assert(generatorNode, 'Missing edge from float generator to float batch'); const generatorField = generatorNode.data.inputs['generator']; @@ -59,16 +64,16 @@ export const resolveBatchValue = async ( const generatorValue = resolveFloatGeneratorField(generatorField); return generatorValue; - } else if (batchNode.data.type === 'integer_batch') { - assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers)); - const ownValue = batchNode.data.inputs.integers.value; - const incomers = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'integers'); + } else if (node.data.type === 'integer_batch') { + assert(isIntegerFieldCollectionInputInstance(node.data.inputs.integers)); + const ownValue = node.data.inputs.integers.value; + const incomers = edges.find((edge) => edge.target === node.id && edge.targetHandle === 'integers'); if (!incomers) { return ownValue ?? []; } - const generatorNode = nodes.find((node) => node.id === incomers.source); + const generatorNode = invocationNodes.find((node) => node.id === incomers.source); assert(generatorNode, 'Missing edge from integer generator to integer batch'); const generatorField = generatorNode.data.inputs['generator']; @@ -79,3 +84,80 @@ export const resolveBatchValue = async ( } assert(false, 'Invalid batch node type'); }; + +export type BatchSizeResult = number | 'EMPTY_BATCHES' | 'NO_BATCHES' | 'MISMATCHED_BATCH_GROUP'; + +export const getBatchSize = async (nodesState: NodesState, dispatch: AppDispatch): Promise => { + const { nodes } = nodesState; + const batchNodes = nodes.filter(isInvocationNode).filter(isBatchNode); + const ungroupedBatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'None'); + const group1BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 1'); + const group2BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 2'); + const group3BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 3'); + const group4BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 4'); + const group5BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 5'); + + const ungroupedBatchSizes = await Promise.all( + ungroupedBatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group1BatchSizes = await Promise.all( + group1BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group2BatchSizes = await Promise.all( + group2BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group3BatchSizes = await Promise.all( + group3BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group4BatchSizes = await Promise.all( + group4BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group5BatchSizes = await Promise.all( + group5BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + + // All batch nodes _must_ have a populated collection + + const allBatchSizes = [ + ...ungroupedBatchSizes, + ...group1BatchSizes, + ...group2BatchSizes, + ...group3BatchSizes, + ...group4BatchSizes, + ...group5BatchSizes, + ]; + + // There are no batch nodes + if (allBatchSizes.length === 0) { + return 'NO_BATCHES'; + } + + // All batch nodes must have a populated collection + if (allBatchSizes.some((size) => size === 0)) { + return 'EMPTY_BATCHES'; + } + + for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) { + // Ignore groups with no batch nodes + if (group.length === 0) { + continue; + } + // Grouped batch nodes must have the same collection size + if (group.some((size) => size !== group[0])) { + return 'MISMATCHED_BATCH_GROUP'; + } + } + + // Total batch size = product of all ungrouped batches and each grouped batch + const totalBatchSize = [ + ...ungroupedBatchSizes, + // In case of no batch nodes in a group, fall back to 1 for the product calculation + group1BatchSizes[0] ?? 1, + group2BatchSizes[0] ?? 1, + group3BatchSizes[0] ?? 1, + group4BatchSizes[0] ?? 1, + group5BatchSizes[0] ?? 1, + ].reduce((acc, size) => acc * size, 1); + + return totalBatchSize; +}; diff --git a/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx b/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx index f59c1b443d..d5f608e97e 100644 --- a/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx +++ b/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx @@ -1,21 +1,21 @@ import type { TooltipProps } from '@invoke-ai/ui-library'; import { Divider, Flex, ListItem, Text, Tooltip, UnorderedList } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { selectSendToCanvas } from 'features/controlLayers/store/canvasSettingsSlice'; import { selectIterations } from 'features/controlLayers/store/paramsSlice'; import { selectDynamicPromptsIsLoading } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; +import { selectNodesSlice } from 'features/nodes/store/selectors'; +import type { NodesState } from 'features/nodes/store/types'; +import type { BatchSizeResult } from 'features/nodes/util/node/resolveBatchValue'; +import { getBatchSize } from 'features/nodes/util/node/resolveBatchValue'; import type { Reason } from 'features/queue/store/readiness'; -import { - $isReadyToEnqueue, - $reasonsWhyCannotEnqueue, - selectPromptsCount, - selectWorkflowsBatchSize, -} from 'features/queue/store/readiness'; +import { $isReadyToEnqueue, $reasonsWhyCannotEnqueue, selectPromptsCount } from 'features/queue/store/readiness'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import { debounce } from 'lodash-es'; import type { PropsWithChildren } from 'react'; -import { memo, useMemo } from 'react'; +import { memo, useEffect, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { enqueueMutationFixedCacheKeyOptions, useEnqueueBatchMutation } from 'services/api/endpoints/queue'; import { useBoardName } from 'services/api/hooks/useBoardName'; @@ -129,10 +129,27 @@ QueueCountPredictionCanvasOrUpscaleTab.displayName = 'QueueCountPredictionCanvas const QueueCountPredictionWorkflowsTab = memo(() => { const { t } = useTranslation(); - const batchSize = useAppSelector(selectWorkflowsBatchSize); + const dispatch = useAppDispatch(); + const nodesState = useAppSelector(selectNodesSlice); + const [batchSize, setBatchSize] = useState('LOADING'); + const debouncedUpdateBatchSize = useMemo( + () => + debounce(async (nodesState: NodesState) => { + setBatchSize('LOADING'); + const batchSize = await getBatchSize(nodesState, dispatch); + setBatchSize(batchSize); + }, 300), + [dispatch] + ); + useEffect(() => { + debouncedUpdateBatchSize(nodesState); + }, [debouncedUpdateBatchSize, nodesState]); const iterationsCount = useAppSelector(selectIterations); const text = useMemo(() => { + if (batchSize === 'LOADING') { + return `${t('common.loading')}...`; + } const iterations = t('queue.iterations', { count: iterationsCount }); if (batchSize === 'NO_BATCHES') { const generationCount = Math.min(10000, iterationsCount); @@ -140,7 +157,10 @@ const QueueCountPredictionWorkflowsTab = memo(() => { return `${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase(); } if (batchSize === 'EMPTY_BATCHES') { - return t('parameters.invoke.invalidBatchConfigurationCannotCalculate'); + return t('parameters.invoke.batchNodeEmptyCollection'); + } + if (batchSize === 'MISMATCHED_BATCH_GROUP') { + return t('parameters.invoke.batchNodeCollectionSizeMismatchNoGroupId'); } const generationCount = Math.min(batchSize * iterationsCount, 10000); const generations = t('queue.generations', { count: generationCount }); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index e0a2471702..6a92033b54 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -1,8 +1,9 @@ import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { EMPTY_ARRAY } from 'app/store/constants'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppStore } from 'app/store/nanostores/store'; import { $true } from 'app/store/nanostores/util'; +import type { AppDispatch, AppStore } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import type { AppConfig } from 'app/types/invokeai'; import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; @@ -80,7 +81,8 @@ const debouncedUpdateReasons = debounce( workflowSettings: WorkflowSettingsState, templates: Templates, upscale: UpscaleState, - config: AppConfig + config: AppConfig, + store: AppStore ) => { if (tab === 'canvas') { const reasons = await getReasonsWhyCannotEnqueueCanvasTab({ @@ -96,10 +98,11 @@ const debouncedUpdateReasons = debounce( }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'workflows') { - const reasons = getReasonsWhyCannotEnqueueWorkflowsTab({ + const reasons = await getReasonsWhyCannotEnqueueWorkflowsTab({ + dispatch: store.dispatch, + nodesState: nodes, + workflowSettingsState: workflowSettings, isConnected, - nodes, - workflowSettings, templates, }); $reasonsWhyCannotEnqueue.set(reasons); @@ -120,6 +123,7 @@ const debouncedUpdateReasons = debounce( export const useReadinessWatcher = () => { useAssertSingleton('useReadinessWatcher'); + const store = useAppStore(); const canvasManager = useCanvasManagerSafe(); const tab = useAppSelector(selectActiveTab); const canvas = useAppSelector(selectCanvasSlice); @@ -153,9 +157,11 @@ export const useReadinessWatcher = () => { workflowSettings, templates, upscale, - config + config, + store ); }, [ + store, canvas, canvasIsCompositing, canvasIsFiltering, @@ -176,21 +182,23 @@ export const useReadinessWatcher = () => { const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') }); -const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: { +const getReasonsWhyCannotEnqueueWorkflowsTab = async (arg: { + dispatch: AppDispatch; + nodesState: NodesState; + workflowSettingsState: WorkflowSettingsState; isConnected: boolean; - nodes: NodesState; - workflowSettings: WorkflowSettingsState; templates: Templates; -}): Reason[] => { - const { isConnected, nodes, workflowSettings, templates } = arg; +}): Promise => { + const { dispatch, nodesState, workflowSettingsState, isConnected, templates } = arg; const reasons: Reason[] = []; if (!isConnected) { reasons.push(disconnectedReason(i18n.t)); } - if (workflowSettings.shouldValidateGraph) { - const invocationNodes = nodes.nodes.filter(isInvocationNode); + if (workflowSettingsState.shouldValidateGraph) { + const { nodes, edges } = nodesState; + const invocationNodes = nodes.filter(isInvocationNode); const batchNodes = invocationNodes.filter(isBatchNode); const executableNodes = invocationNodes.filter(isExecutableNode); @@ -199,7 +207,7 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: { } for (const node of batchNodes) { - if (nodes.edges.find((e) => e.source === node.id) === undefined) { + if (edges.find((e) => e.source === node.id) === undefined) { reasons.push({ content: i18n.t('parameters.invoke.batchNodeNotConnected', { label: node.data.label }) }); } } @@ -212,7 +220,7 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: { const groupBatchSizes: number[] = []; for (const node of batchNodes) { - const size = resolveBatchValue(node, invocationNodes, nodes.edges).length; + const size = (await resolveBatchValue({ dispatch, nodesState, node })).length; if (batchGroupId === 'None') { // Ungrouped batch nodes may have differing collection sizes batchSizes.push(size); @@ -237,12 +245,12 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: { } } - executableNodes.forEach((node) => { + invocationNodes.forEach((node) => { if (!isInvocationNode(node)) { return; } - const errors = getInvocationNodeErrors(node.data.id, templates, nodes); + const errors = getInvocationNodeErrors(node.data.id, templates, nodesState); for (const error of errors) { if (error.type === 'node-error') { @@ -490,81 +498,3 @@ export const selectPromptsCount = createSelector( selectDynamicPromptsSlice, (params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1) ); - -const buildSelectGroupBatchSizes = (batchGroupId: string) => - createMemoizedSelector(selectNodesSlice, ({ nodes, edges }) => { - const invocationNodes = nodes.filter(isInvocationNode); - return invocationNodes - .filter(isBatchNode) - .filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId) - .map((batchNodes) => resolveBatchValue(batchNodes, invocationNodes, edges).length); - }); - -const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None'); -const selectGroup1BatchSizes = buildSelectGroupBatchSizes('Group 1'); -const selectGroup2BatchSizes = buildSelectGroupBatchSizes('Group 2'); -const selectGroup3BatchSizes = buildSelectGroupBatchSizes('Group 3'); -const selectGroup4BatchSizes = buildSelectGroupBatchSizes('Group 4'); -const selectGroup5BatchSizes = buildSelectGroupBatchSizes('Group 5'); - -export const selectWorkflowsBatchSize = createSelector( - selectUngroupedBatchSizes, - selectGroup1BatchSizes, - selectGroup2BatchSizes, - selectGroup3BatchSizes, - selectGroup4BatchSizes, - selectGroup5BatchSizes, - ( - ungroupedBatchSizes, - group1BatchSizes, - group2BatchSizes, - group3BatchSizes, - group4BatchSizes, - group5BatchSizes - ): number | 'EMPTY_BATCHES' | 'NO_BATCHES' => { - // All batch nodes _must_ have a populated collection - - const allBatchSizes = [ - ...ungroupedBatchSizes, - ...group1BatchSizes, - ...group2BatchSizes, - ...group3BatchSizes, - ...group4BatchSizes, - ...group5BatchSizes, - ]; - - // There are no batch nodes - if (allBatchSizes.length === 0) { - return 'NO_BATCHES'; - } - - // All batch nodes must have a populated collection - if (allBatchSizes.some((size) => size === 0)) { - return 'EMPTY_BATCHES'; - } - - for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) { - // Ignore groups with no batch nodes - if (group.length === 0) { - continue; - } - // Grouped batch nodes must have the same collection size - if (group.some((size) => size !== group[0])) { - return 'EMPTY_BATCHES'; - } - } - - // Total batch size = product of all ungrouped batches and each grouped batch - const totalBatchSize = [ - ...ungroupedBatchSizes, - // In case of no batch nodes in a group, fall back to 1 for the product calculation - group1BatchSizes[0] ?? 1, - group2BatchSizes[0] ?? 1, - group3BatchSizes[0] ?? 1, - group4BatchSizes[0] ?? 1, - group5BatchSizes[0] ?? 1, - ].reduce((acc, size) => acc * size, 1); - - return totalBatchSize; - } -); diff --git a/invokeai/frontend/web/src/services/api/endpoints/utilities.ts b/invokeai/frontend/web/src/services/api/endpoints/utilities.ts index d10e8c2103..0b8839032c 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/utilities.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/utilities.ts @@ -28,5 +28,3 @@ export const utilitiesApi = api.injectEndpoints({ }), }), }); - -export const { useDynamicPromptsQuery } = utilitiesApi;