fix(ui): handle batch group ids of "None" correctly

This commit is contained in:
psychedelicious
2025-01-14 16:11:31 +11:00
parent 3df3be6c34
commit 30e33d30d5
2 changed files with 102 additions and 20 deletions

View File

@@ -98,7 +98,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
if (batchGroupId) {
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromImageBatch, images.value);
} else {
addProductBatchDataCollectionItem(edgesFromImageBatch, images.value);
@@ -117,7 +117,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
if (batchGroupId) {
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, strings.value);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, strings.value);
@@ -140,7 +140,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
if (batchGroupId) {
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, integers.value);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, integers.value);
@@ -163,7 +163,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
if (batchGroupId) {
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, floats.value);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, floats.value);
@@ -171,7 +171,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
}
// Finally, if this batch data collection item has any items, add it to the data array
if (batchGroupId && zippedBatchDataCollectionItems.length > 0) {
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
data.push(zippedBatchDataCollectionItems);
}
}

View File

@@ -1,4 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { AppConfig } from 'app/types/invokeai';
import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
@@ -33,6 +34,7 @@ import {
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import type { InvocationNode } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
@@ -87,12 +89,12 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
if (batchNodes.length > 1) {
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
if (!batchGroupId) {
// No batch group ID implies ungrouped batch nodes, their collection sizes can be different
if (batchGroupId === 'None') {
// Ungrouped batch nodes may have differing collection sizes
continue;
}
// Otherwise, all batch nodes with the same link ID must have the same collection size
// But grouped batch nodes must have the same collection size
const collectionSizes: number[] = [];
for (const node of batchNodes) {
@@ -538,17 +540,97 @@ export const selectPromptsCount = createSelector(
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
);
export const selectWorkflowsBatchSize = createSelector(selectNodesSlice, ({ nodes }) =>
// The batch size is the product of all batch nodes' collection sizes
nodes.filter(isInvocationNode).reduce((batchSize, node) => {
if (!isImageFieldCollectionInputInstance(node.data.inputs.images)) {
return batchSize;
}
// If the batch size is not set, default to 1
batchSize = batchSize || 1;
// Multiply the batch size by the number of images in the batch
batchSize = batchSize * (node.data.inputs.images.value?.length ?? 0);
const getBatchCollectionSize = (batchNode: InvocationNode) => {
if (batchNode.data.type === 'image_batch') {
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
return batchNode.data.inputs.images.value?.length ?? 0;
} else if (batchNode.data.type === 'string_batch') {
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
return batchNode.data.inputs.strings.value?.length ?? 0;
} else if (batchNode.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
return batchNode.data.inputs.floats.value?.length ?? 0;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
return batchNode.data.inputs.integers.value?.length ?? 0;
}
return 0;
};
return batchSize;
}, 0)
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
createMemoizedSelector(selectNodesSlice, ({ nodes }) => {
return nodes
.filter(isInvocationNode)
.filter(isBatchNode)
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
.map(getBatchCollectionSize);
});
export const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');
export const selectGroup1BatchSizes = buildSelectGroupBatchSizes('Group 1');
export const selectGroup2BatchSizes = buildSelectGroupBatchSizes('Group 2');
export const selectGroup3BatchSizes = buildSelectGroupBatchSizes('Group 3');
export const selectGroup4BatchSizes = buildSelectGroupBatchSizes('Group 4');
export const selectGroup5BatchSizes = buildSelectGroupBatchSizes('Group 5');
export const selectWorkflowsBatchSize = createSelector(
selectUngroupedBatchSizes,
selectGroup1BatchSizes,
selectGroup2BatchSizes,
selectGroup3BatchSizes,
selectGroup4BatchSizes,
selectGroup5BatchSizes,
(
ungroupedBatchSizes,
group1BatchSizes,
group2BatchSizes,
group3BatchSizes,
group4BatchSizes,
group5BatchSizes
): number | 'INVALID' | 'NO_BATCHES' => {
// All batch nodes _must_ have a populated collection
const allBatchSizes = [
...ungroupedBatchSizes,
...group1BatchSizes,
...group2BatchSizes,
...group3BatchSizes,
...group4BatchSizes,
...group5BatchSizes,
];
// There are no batch nodes
if (allBatchSizes.length === 0) {
return 'NO_BATCHES';
}
// All batch nodes must have a populated collection
if (allBatchSizes.some((size) => size === 0)) {
return 'INVALID';
}
for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) {
// Ignore groups with no batch nodes
if (group.length === 0) {
continue;
}
// Grouped batch nodes must have the same collection size
if (group.some((size) => size !== group[0])) {
return 'INVALID';
}
}
// Total batch size = product of all ungrouped batches and each grouped batch
const totalBatchSize = [
...ungroupedBatchSizes,
// In case of no batch nodes in a group, fall back to 1 for the product calculation
group1BatchSizes[0] ?? 1,
group2BatchSizes[0] ?? 1,
group3BatchSizes[0] ?? 1,
group4BatchSizes[0] ?? 1,
group5BatchSizes[0] ?? 1,
].reduce((acc, size) => acc * size, 1);
return totalBatchSize;
}
);