diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 6370d34f26..459bf17d6b 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1080,7 +1080,7 @@ "emptyBatches": "empty batches", "batchNodeNotConnected": "Batch node not connected: {{label}}", "batchNodeEmptyCollection": "Some batch nodes have empty collections", - "invalidBatchConfigurationCannotCalculate": "Invalid batch configuration; cannot calculate", + "collectionEmpty": "empty collection", "collectionTooFewItems": "too few items, minimum {{minItems}}", "collectionTooManyItems": "too many items, maximum {{maxItems}}", "collectionStringTooLong": "too long, max {{maxLength}}", @@ -1090,6 +1090,7 @@ "collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)", "collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)", "collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}", + "batchNodeCollectionSizeMismatchNoGroupId": "Batch group collection size mismatch", "batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}", "noModelSelected": "No model selected", "noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation", diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts index fa5fca0757..6a14028ad9 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts @@ -16,13 +16,13 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) = enqueueRequested.match(action) && action.payload.tabName === 'workflows', effect: async (action, { getState, dispatch }) => { const state = getState(); - const nodes = selectNodesSlice(state); + const nodesState = selectNodesSlice(state); const workflow = state.workflow; const templates = $templates.get(); const graph = buildNodesGraph(state, templates); const builtWorkflow = buildWorkflowWithValidation({ - nodes: nodes.nodes, - edges: nodes.edges, + nodes: nodesState.nodes, + edges: nodesState.edges, workflow, }); @@ -33,7 +33,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) = const data: Batch['data'] = []; - const invocationNodes = nodes.nodes.filter(isInvocationNode); + const invocationNodes = nodesState.nodes.filter(isInvocationNode); const batchNodes = invocationNodes.filter(isBatchNode); // Handle zipping batch nodes. First group the batch nodes by their batch_group_id @@ -44,9 +44,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) = const zippedBatchDataCollectionItems: NonNullable[number] = []; for (const node of batchNodes) { - const value = resolveBatchValue(node, invocationNodes, nodes.edges); + const value = await resolveBatchValue({ nodesState, node, dispatch }); const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value'; - const edgesFromBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === sourceHandle); + const edgesFromBatch = nodesState.edges.filter( + (e) => e.source === node.id && e.sourceHandle === sourceHandle + ); if (batchGroupId !== 'None') { // If this batch node has a batch_group_id, we will zip the data collection items for (const edge of edgesFromBatch) { diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent.tsx index 4e57bf87ff..2c5848043e 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorFieldComponent.tsx @@ -1,5 +1,4 @@ import { Flex, Select, Text } from '@invoke-ai/ui-library'; -import { useAppStore } from 'app/store/nanostores/store'; import { useAppDispatch } from 'app/store/storeHooks'; import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants'; import { StringGeneratorDynamicPromptsCombinatorialSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringGeneratorDynamicPromptsCombinatorialSettings'; @@ -29,7 +28,6 @@ export const StringGeneratorFieldInputComponent = memo( const { nodeId, field } = props; const { t } = useTranslation(); const dispatch = useAppDispatch(); - const store = useAppStore(); const onChange = useCallback( (value: StringGeneratorFieldInputInstance['value']) => { @@ -62,14 +60,14 @@ export const StringGeneratorFieldInputComponent = memo( const resolveAndSetValuesAsString = useMemo( () => debounce(async (field: StringGeneratorFieldInputInstance) => { - const resolvedValues = await resolveStringGeneratorField(field, store); + const resolvedValues = await resolveStringGeneratorField(field, dispatch); if (resolvedValues.length === 0) { setResolvedValuesAsString(`<${t('nodes.generatorNoValues')}>`); } else { setResolvedValuesAsString(resolvedValues.join(', ')); } }, 300), - [store, t] + [dispatch, t] ); useEffect(() => { resolveAndSetValuesAsString(field); diff --git a/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts b/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts index 76d2962e4c..94fa97ecab 100644 --- a/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts +++ b/invokeai/frontend/web/src/features/nodes/store/util/fieldValidators.ts @@ -36,14 +36,14 @@ const validateImageFieldCollectionValue = ( // Image collections may have min or max items to validate if (minItems !== undefined && minItems > 0 && count === 0) { reasons.push(t('parameters.invoke.collectionEmpty')); - } + } else { + if (minItems !== undefined && count < minItems) { + reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); + } - if (minItems !== undefined && count < minItems) { - reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); - } - - if (maxItems !== undefined && count > maxItems) { - reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + if (maxItems !== undefined && count > maxItems) { + reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + } } return reasons; @@ -60,14 +60,14 @@ const validateStringFieldCollectionValue = ( // Image collections may have min or max items to validate if (minItems !== undefined && minItems > 0 && count === 0) { reasons.push(t('parameters.invoke.collectionEmpty')); - } + } else { + if (minItems !== undefined && count < minItems) { + reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); + } - if (minItems !== undefined && count < minItems) { - reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); - } - - if (maxItems !== undefined && count > maxItems) { - reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + if (maxItems !== undefined && count > maxItems) { + reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + } } for (const str of value) { @@ -93,14 +93,14 @@ const validateNumberFieldCollectionValue = ( // Image collections may have min or max items to validate if (minItems !== undefined && minItems > 0 && count === 0) { reasons.push(t('parameters.invoke.collectionEmpty')); - } + } else { + if (minItems !== undefined && count < minItems) { + reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); + } - if (minItems !== undefined && count < minItems) { - reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems })); - } - - if (maxItems !== undefined && count > maxItems) { - reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + if (maxItems !== undefined && count > maxItems) { + reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems })); + } } for (const num of value) { @@ -174,6 +174,8 @@ export const getFieldErrors = ( prefix, issue: t('parameters.invoke.missingInputForField'), }); + } else if (isConnected) { + // Connected fields have no value to validate - they are OK } else if ( field.value && isImageFieldCollectionInputTemplate(fieldTemplate) && diff --git a/invokeai/frontend/web/src/features/nodes/types/field.ts b/invokeai/frontend/web/src/features/nodes/types/field.ts index d405362c0f..e293b81e89 100644 --- a/invokeai/frontend/web/src/features/nodes/types/field.ts +++ b/invokeai/frontend/web/src/features/nodes/types/field.ts @@ -1,5 +1,5 @@ import { EMPTY_ARRAY } from 'app/store/constants'; -import type { AppStore } from 'app/store/store'; +import type { AppDispatch } from 'app/store/store'; import { isNil, random, trim } from 'lodash-es'; import MersenneTwister from 'mtwist'; import { utilitiesApi } from 'services/api/endpoints/utilities'; @@ -1414,10 +1414,10 @@ const getStringGeneratorDynamicPromptsCombinatorialDefaults = () => zStringGeneratorDynamicPromptsCombinatorial.parse({}); const getStringGeneratorDynamicPromptsCombinatorialValues = async ( generator: StringGeneratorDynamicPromptsCombinatorial, - store: AppStore + dispatch: AppDispatch ): Promise => { const { input, maxPrompts } = generator; - const req = store.dispatch( + const req = dispatch( utilitiesApi.endpoints.dynamicPrompts.initiate( { prompt: input, @@ -1452,10 +1452,10 @@ export type StringGeneratorDynamicPromptsRandom = z.infer zStringGeneratorDynamicPromptsRandom.parse({}); const getStringGeneratorDynamicPromptsRandomValues = async ( generator: StringGeneratorDynamicPromptsRandom, - store: AppStore + dispatch: AppDispatch ): Promise => { const { input, seed, count } = generator; - const req = store.dispatch( + const req = dispatch( utilitiesApi.endpoints.dynamicPrompts.initiate( { prompt: input, @@ -1503,7 +1503,10 @@ export const isStringGeneratorFieldInputTemplate = buildTemplateTypeGuard { +export const resolveStringGeneratorField = async ( + { value }: StringGeneratorFieldInputInstance, + dispatch: AppDispatch +) => { if (value.values) { return value.values; } @@ -1511,10 +1514,10 @@ export const resolveStringGeneratorField = async ({ value }: StringGeneratorFiel return getStringGeneratorParseStringValues(value); } if (value.type === StringGeneratorDynamicPromptsRandomType) { - return await getStringGeneratorDynamicPromptsRandomValues(value, store); + return await getStringGeneratorDynamicPromptsRandomValues(value, dispatch); } if (value.type === StringGeneratorDynamicPromptsCombinatorialType) { - return await getStringGeneratorDynamicPromptsCombinatorialValues(value, store); + return await getStringGeneratorDynamicPromptsCombinatorialValues(value, dispatch); } assert(false, 'Invalid string generator type'); }; diff --git a/invokeai/frontend/web/src/features/nodes/util/node/resolveBatchValue.ts b/invokeai/frontend/web/src/features/nodes/util/node/resolveBatchValue.ts index 52098839a2..a4cb58baa0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/node/resolveBatchValue.ts +++ b/invokeai/frontend/web/src/features/nodes/util/node/resolveBatchValue.ts @@ -1,4 +1,5 @@ -import type { AppStore } from 'app/store/store'; +import type { AppDispatch } from 'app/store/store'; +import type { NodesState } from 'features/nodes/store/types'; import { isFloatFieldCollectionInputInstance, isFloatGeneratorFieldInputInstance, @@ -11,47 +12,51 @@ import { resolveIntegerGeneratorField, resolveStringGeneratorField, } from 'features/nodes/types/field'; -import type { AnyEdge, InvocationNode } from 'features/nodes/types/invocation'; +import type { InvocationNode } from 'features/nodes/types/invocation'; +import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation'; import { assert } from 'tsafe'; -export const resolveBatchValue = async ( - batchNode: InvocationNode, - nodes: InvocationNode[], - edges: AnyEdge[], - store: AppStore -) => { - if (batchNode.data.type === 'image_batch') { - assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images)); - const ownValue = batchNode.data.inputs.images.value ?? []; +export const resolveBatchValue = async (arg: { + dispatch: AppDispatch; + nodesState: NodesState; + node: InvocationNode; +}) => { + const { node, dispatch, nodesState } = arg; + const { nodes, edges } = nodesState; + const invocationNodes = nodes.filter(isInvocationNode); + + if (node.data.type === 'image_batch') { + assert(isImageFieldCollectionInputInstance(node.data.inputs.images)); + const ownValue = node.data.inputs.images.value ?? []; // no generators for images yet return ownValue; - } else if (batchNode.data.type === 'string_batch') { - assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings)); - const ownValue = batchNode.data.inputs.strings.value; - const edgeToStrings = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'strings'); + } else if (node.data.type === 'string_batch') { + assert(isStringFieldCollectionInputInstance(node.data.inputs.strings)); + const ownValue = node.data.inputs.strings.value; + const edgeToStrings = edges.find((edge) => edge.target === node.id && edge.targetHandle === 'strings'); if (!edgeToStrings) { return ownValue ?? []; } - const generatorNode = nodes.find((node) => node.id === edgeToStrings.source); + const generatorNode = invocationNodes.find((node) => node.id === edgeToStrings.source); assert(generatorNode, 'Missing edge from string generator to string batch'); const generatorField = generatorNode.data.inputs['generator']; assert(isStringGeneratorFieldInputInstance(generatorField), 'Invalid string generator'); - const generatorValue = await resolveStringGeneratorField(generatorField, store); + const generatorValue = await resolveStringGeneratorField(generatorField, dispatch); return generatorValue; - } else if (batchNode.data.type === 'float_batch') { - assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats)); - const ownValue = batchNode.data.inputs.floats.value; - const edgeToFloats = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'floats'); + } else if (node.data.type === 'float_batch') { + assert(isFloatFieldCollectionInputInstance(node.data.inputs.floats)); + const ownValue = node.data.inputs.floats.value; + const edgeToFloats = edges.find((edge) => edge.target === node.id && edge.targetHandle === 'floats'); if (!edgeToFloats) { return ownValue ?? []; } - const generatorNode = nodes.find((node) => node.id === edgeToFloats.source); + const generatorNode = invocationNodes.find((node) => node.id === edgeToFloats.source); assert(generatorNode, 'Missing edge from float generator to float batch'); const generatorField = generatorNode.data.inputs['generator']; @@ -59,16 +64,16 @@ export const resolveBatchValue = async ( const generatorValue = resolveFloatGeneratorField(generatorField); return generatorValue; - } else if (batchNode.data.type === 'integer_batch') { - assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers)); - const ownValue = batchNode.data.inputs.integers.value; - const incomers = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'integers'); + } else if (node.data.type === 'integer_batch') { + assert(isIntegerFieldCollectionInputInstance(node.data.inputs.integers)); + const ownValue = node.data.inputs.integers.value; + const incomers = edges.find((edge) => edge.target === node.id && edge.targetHandle === 'integers'); if (!incomers) { return ownValue ?? []; } - const generatorNode = nodes.find((node) => node.id === incomers.source); + const generatorNode = invocationNodes.find((node) => node.id === incomers.source); assert(generatorNode, 'Missing edge from integer generator to integer batch'); const generatorField = generatorNode.data.inputs['generator']; @@ -79,3 +84,80 @@ export const resolveBatchValue = async ( } assert(false, 'Invalid batch node type'); }; + +export type BatchSizeResult = number | 'EMPTY_BATCHES' | 'NO_BATCHES' | 'MISMATCHED_BATCH_GROUP'; + +export const getBatchSize = async (nodesState: NodesState, dispatch: AppDispatch): Promise => { + const { nodes } = nodesState; + const batchNodes = nodes.filter(isInvocationNode).filter(isBatchNode); + const ungroupedBatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'None'); + const group1BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 1'); + const group2BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 2'); + const group3BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 3'); + const group4BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 4'); + const group5BatchNodes = batchNodes.filter((node) => node.data.inputs['batch_group_id']?.value === 'Group 5'); + + const ungroupedBatchSizes = await Promise.all( + ungroupedBatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group1BatchSizes = await Promise.all( + group1BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group2BatchSizes = await Promise.all( + group2BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group3BatchSizes = await Promise.all( + group3BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group4BatchSizes = await Promise.all( + group4BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + const group5BatchSizes = await Promise.all( + group5BatchNodes.map(async (node) => (await resolveBatchValue({ nodesState, dispatch, node })).length) + ); + + // All batch nodes _must_ have a populated collection + + const allBatchSizes = [ + ...ungroupedBatchSizes, + ...group1BatchSizes, + ...group2BatchSizes, + ...group3BatchSizes, + ...group4BatchSizes, + ...group5BatchSizes, + ]; + + // There are no batch nodes + if (allBatchSizes.length === 0) { + return 'NO_BATCHES'; + } + + // All batch nodes must have a populated collection + if (allBatchSizes.some((size) => size === 0)) { + return 'EMPTY_BATCHES'; + } + + for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) { + // Ignore groups with no batch nodes + if (group.length === 0) { + continue; + } + // Grouped batch nodes must have the same collection size + if (group.some((size) => size !== group[0])) { + return 'MISMATCHED_BATCH_GROUP'; + } + } + + // Total batch size = product of all ungrouped batches and each grouped batch + const totalBatchSize = [ + ...ungroupedBatchSizes, + // In case of no batch nodes in a group, fall back to 1 for the product calculation + group1BatchSizes[0] ?? 1, + group2BatchSizes[0] ?? 1, + group3BatchSizes[0] ?? 1, + group4BatchSizes[0] ?? 1, + group5BatchSizes[0] ?? 1, + ].reduce((acc, size) => acc * size, 1); + + return totalBatchSize; +}; diff --git a/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx b/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx index f59c1b443d..d5f608e97e 100644 --- a/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx +++ b/invokeai/frontend/web/src/features/queue/components/InvokeButtonTooltip/InvokeButtonTooltip.tsx @@ -1,21 +1,21 @@ import type { TooltipProps } from '@invoke-ai/ui-library'; import { Divider, Flex, ListItem, Text, Tooltip, UnorderedList } from '@invoke-ai/ui-library'; import { useStore } from '@nanostores/react'; -import { useAppSelector } from 'app/store/storeHooks'; +import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { selectSendToCanvas } from 'features/controlLayers/store/canvasSettingsSlice'; import { selectIterations } from 'features/controlLayers/store/paramsSlice'; import { selectDynamicPromptsIsLoading } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { selectAutoAddBoardId } from 'features/gallery/store/gallerySelectors'; +import { selectNodesSlice } from 'features/nodes/store/selectors'; +import type { NodesState } from 'features/nodes/store/types'; +import type { BatchSizeResult } from 'features/nodes/util/node/resolveBatchValue'; +import { getBatchSize } from 'features/nodes/util/node/resolveBatchValue'; import type { Reason } from 'features/queue/store/readiness'; -import { - $isReadyToEnqueue, - $reasonsWhyCannotEnqueue, - selectPromptsCount, - selectWorkflowsBatchSize, -} from 'features/queue/store/readiness'; +import { $isReadyToEnqueue, $reasonsWhyCannotEnqueue, selectPromptsCount } from 'features/queue/store/readiness'; import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import { debounce } from 'lodash-es'; import type { PropsWithChildren } from 'react'; -import { memo, useMemo } from 'react'; +import { memo, useEffect, useMemo, useState } from 'react'; import { useTranslation } from 'react-i18next'; import { enqueueMutationFixedCacheKeyOptions, useEnqueueBatchMutation } from 'services/api/endpoints/queue'; import { useBoardName } from 'services/api/hooks/useBoardName'; @@ -129,10 +129,27 @@ QueueCountPredictionCanvasOrUpscaleTab.displayName = 'QueueCountPredictionCanvas const QueueCountPredictionWorkflowsTab = memo(() => { const { t } = useTranslation(); - const batchSize = useAppSelector(selectWorkflowsBatchSize); + const dispatch = useAppDispatch(); + const nodesState = useAppSelector(selectNodesSlice); + const [batchSize, setBatchSize] = useState('LOADING'); + const debouncedUpdateBatchSize = useMemo( + () => + debounce(async (nodesState: NodesState) => { + setBatchSize('LOADING'); + const batchSize = await getBatchSize(nodesState, dispatch); + setBatchSize(batchSize); + }, 300), + [dispatch] + ); + useEffect(() => { + debouncedUpdateBatchSize(nodesState); + }, [debouncedUpdateBatchSize, nodesState]); const iterationsCount = useAppSelector(selectIterations); const text = useMemo(() => { + if (batchSize === 'LOADING') { + return `${t('common.loading')}...`; + } const iterations = t('queue.iterations', { count: iterationsCount }); if (batchSize === 'NO_BATCHES') { const generationCount = Math.min(10000, iterationsCount); @@ -140,7 +157,10 @@ const QueueCountPredictionWorkflowsTab = memo(() => { return `${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase(); } if (batchSize === 'EMPTY_BATCHES') { - return t('parameters.invoke.invalidBatchConfigurationCannotCalculate'); + return t('parameters.invoke.batchNodeEmptyCollection'); + } + if (batchSize === 'MISMATCHED_BATCH_GROUP') { + return t('parameters.invoke.batchNodeCollectionSizeMismatchNoGroupId'); } const generationCount = Math.min(batchSize * iterationsCount, 10000); const generations = t('queue.generations', { count: generationCount }); diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index e0a2471702..6a92033b54 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -1,8 +1,9 @@ import { useStore } from '@nanostores/react'; import { createSelector } from '@reduxjs/toolkit'; import { EMPTY_ARRAY } from 'app/store/constants'; -import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; +import { useAppStore } from 'app/store/nanostores/store'; import { $true } from 'app/store/nanostores/util'; +import type { AppDispatch, AppStore } from 'app/store/store'; import { useAppSelector } from 'app/store/storeHooks'; import type { AppConfig } from 'app/types/invokeai'; import { useAssertSingleton } from 'common/hooks/useAssertSingleton'; @@ -80,7 +81,8 @@ const debouncedUpdateReasons = debounce( workflowSettings: WorkflowSettingsState, templates: Templates, upscale: UpscaleState, - config: AppConfig + config: AppConfig, + store: AppStore ) => { if (tab === 'canvas') { const reasons = await getReasonsWhyCannotEnqueueCanvasTab({ @@ -96,10 +98,11 @@ const debouncedUpdateReasons = debounce( }); $reasonsWhyCannotEnqueue.set(reasons); } else if (tab === 'workflows') { - const reasons = getReasonsWhyCannotEnqueueWorkflowsTab({ + const reasons = await getReasonsWhyCannotEnqueueWorkflowsTab({ + dispatch: store.dispatch, + nodesState: nodes, + workflowSettingsState: workflowSettings, isConnected, - nodes, - workflowSettings, templates, }); $reasonsWhyCannotEnqueue.set(reasons); @@ -120,6 +123,7 @@ const debouncedUpdateReasons = debounce( export const useReadinessWatcher = () => { useAssertSingleton('useReadinessWatcher'); + const store = useAppStore(); const canvasManager = useCanvasManagerSafe(); const tab = useAppSelector(selectActiveTab); const canvas = useAppSelector(selectCanvasSlice); @@ -153,9 +157,11 @@ export const useReadinessWatcher = () => { workflowSettings, templates, upscale, - config + config, + store ); }, [ + store, canvas, canvasIsCompositing, canvasIsFiltering, @@ -176,21 +182,23 @@ export const useReadinessWatcher = () => { const disconnectedReason = (t: typeof i18n.t) => ({ content: t('parameters.invoke.systemDisconnected') }); -const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: { +const getReasonsWhyCannotEnqueueWorkflowsTab = async (arg: { + dispatch: AppDispatch; + nodesState: NodesState; + workflowSettingsState: WorkflowSettingsState; isConnected: boolean; - nodes: NodesState; - workflowSettings: WorkflowSettingsState; templates: Templates; -}): Reason[] => { - const { isConnected, nodes, workflowSettings, templates } = arg; +}): Promise => { + const { dispatch, nodesState, workflowSettingsState, isConnected, templates } = arg; const reasons: Reason[] = []; if (!isConnected) { reasons.push(disconnectedReason(i18n.t)); } - if (workflowSettings.shouldValidateGraph) { - const invocationNodes = nodes.nodes.filter(isInvocationNode); + if (workflowSettingsState.shouldValidateGraph) { + const { nodes, edges } = nodesState; + const invocationNodes = nodes.filter(isInvocationNode); const batchNodes = invocationNodes.filter(isBatchNode); const executableNodes = invocationNodes.filter(isExecutableNode); @@ -199,7 +207,7 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: { } for (const node of batchNodes) { - if (nodes.edges.find((e) => e.source === node.id) === undefined) { + if (edges.find((e) => e.source === node.id) === undefined) { reasons.push({ content: i18n.t('parameters.invoke.batchNodeNotConnected', { label: node.data.label }) }); } } @@ -212,7 +220,7 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: { const groupBatchSizes: number[] = []; for (const node of batchNodes) { - const size = resolveBatchValue(node, invocationNodes, nodes.edges).length; + const size = (await resolveBatchValue({ dispatch, nodesState, node })).length; if (batchGroupId === 'None') { // Ungrouped batch nodes may have differing collection sizes batchSizes.push(size); @@ -237,12 +245,12 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: { } } - executableNodes.forEach((node) => { + invocationNodes.forEach((node) => { if (!isInvocationNode(node)) { return; } - const errors = getInvocationNodeErrors(node.data.id, templates, nodes); + const errors = getInvocationNodeErrors(node.data.id, templates, nodesState); for (const error of errors) { if (error.type === 'node-error') { @@ -490,81 +498,3 @@ export const selectPromptsCount = createSelector( selectDynamicPromptsSlice, (params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1) ); - -const buildSelectGroupBatchSizes = (batchGroupId: string) => - createMemoizedSelector(selectNodesSlice, ({ nodes, edges }) => { - const invocationNodes = nodes.filter(isInvocationNode); - return invocationNodes - .filter(isBatchNode) - .filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId) - .map((batchNodes) => resolveBatchValue(batchNodes, invocationNodes, edges).length); - }); - -const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None'); -const selectGroup1BatchSizes = buildSelectGroupBatchSizes('Group 1'); -const selectGroup2BatchSizes = buildSelectGroupBatchSizes('Group 2'); -const selectGroup3BatchSizes = buildSelectGroupBatchSizes('Group 3'); -const selectGroup4BatchSizes = buildSelectGroupBatchSizes('Group 4'); -const selectGroup5BatchSizes = buildSelectGroupBatchSizes('Group 5'); - -export const selectWorkflowsBatchSize = createSelector( - selectUngroupedBatchSizes, - selectGroup1BatchSizes, - selectGroup2BatchSizes, - selectGroup3BatchSizes, - selectGroup4BatchSizes, - selectGroup5BatchSizes, - ( - ungroupedBatchSizes, - group1BatchSizes, - group2BatchSizes, - group3BatchSizes, - group4BatchSizes, - group5BatchSizes - ): number | 'EMPTY_BATCHES' | 'NO_BATCHES' => { - // All batch nodes _must_ have a populated collection - - const allBatchSizes = [ - ...ungroupedBatchSizes, - ...group1BatchSizes, - ...group2BatchSizes, - ...group3BatchSizes, - ...group4BatchSizes, - ...group5BatchSizes, - ]; - - // There are no batch nodes - if (allBatchSizes.length === 0) { - return 'NO_BATCHES'; - } - - // All batch nodes must have a populated collection - if (allBatchSizes.some((size) => size === 0)) { - return 'EMPTY_BATCHES'; - } - - for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) { - // Ignore groups with no batch nodes - if (group.length === 0) { - continue; - } - // Grouped batch nodes must have the same collection size - if (group.some((size) => size !== group[0])) { - return 'EMPTY_BATCHES'; - } - } - - // Total batch size = product of all ungrouped batches and each grouped batch - const totalBatchSize = [ - ...ungroupedBatchSizes, - // In case of no batch nodes in a group, fall back to 1 for the product calculation - group1BatchSizes[0] ?? 1, - group2BatchSizes[0] ?? 1, - group3BatchSizes[0] ?? 1, - group4BatchSizes[0] ?? 1, - group5BatchSizes[0] ?? 1, - ].reduce((acc, size) => acc * size, 1); - - return totalBatchSize; - } -); diff --git a/invokeai/frontend/web/src/services/api/endpoints/utilities.ts b/invokeai/frontend/web/src/services/api/endpoints/utilities.ts index d10e8c2103..0b8839032c 100644 --- a/invokeai/frontend/web/src/services/api/endpoints/utilities.ts +++ b/invokeai/frontend/web/src/services/api/endpoints/utilities.ts @@ -28,5 +28,3 @@ export const utilitiesApi = api.injectEndpoints({ }), }), }); - -export const { useDynamicPromptsQuery } = utilitiesApi;