mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
fix(ui): handle batch group ids of "None" correctly
This commit is contained in:
@@ -98,7 +98,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
|
||||
if (batchGroupId) {
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromImageBatch, images.value);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromImageBatch, images.value);
|
||||
@@ -117,7 +117,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
if (batchGroupId) {
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, strings.value);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, strings.value);
|
||||
@@ -140,7 +140,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
if (batchGroupId) {
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, integers.value);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, integers.value);
|
||||
@@ -163,7 +163,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
if (batchGroupId) {
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, floats.value);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, floats.value);
|
||||
@@ -171,7 +171,7 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
}
|
||||
|
||||
// Finally, if this batch data collection item has any items, add it to the data array
|
||||
if (batchGroupId && zippedBatchDataCollectionItems.length > 0) {
|
||||
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
|
||||
data.push(zippedBatchDataCollectionItems);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { AppConfig } from 'app/types/invokeai';
|
||||
import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
@@ -33,6 +34,7 @@ import {
|
||||
validateNumberFieldCollectionValue,
|
||||
validateStringFieldCollectionValue,
|
||||
} from 'features/nodes/types/fieldValidators';
|
||||
import type { InvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
@@ -87,12 +89,12 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
if (batchNodes.length > 1) {
|
||||
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
|
||||
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
|
||||
if (!batchGroupId) {
|
||||
// No batch group ID implies ungrouped batch nodes, their collection sizes can be different
|
||||
if (batchGroupId === 'None') {
|
||||
// Ungrouped batch nodes may have differing collection sizes
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, all batch nodes with the same link ID must have the same collection size
|
||||
// But grouped batch nodes must have the same collection size
|
||||
const collectionSizes: number[] = [];
|
||||
|
||||
for (const node of batchNodes) {
|
||||
@@ -538,17 +540,97 @@ export const selectPromptsCount = createSelector(
|
||||
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
|
||||
);
|
||||
|
||||
export const selectWorkflowsBatchSize = createSelector(selectNodesSlice, ({ nodes }) =>
|
||||
// The batch size is the product of all batch nodes' collection sizes
|
||||
nodes.filter(isInvocationNode).reduce((batchSize, node) => {
|
||||
if (!isImageFieldCollectionInputInstance(node.data.inputs.images)) {
|
||||
return batchSize;
|
||||
}
|
||||
// If the batch size is not set, default to 1
|
||||
batchSize = batchSize || 1;
|
||||
// Multiply the batch size by the number of images in the batch
|
||||
batchSize = batchSize * (node.data.inputs.images.value?.length ?? 0);
|
||||
const getBatchCollectionSize = (batchNode: InvocationNode) => {
|
||||
if (batchNode.data.type === 'image_batch') {
|
||||
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
|
||||
return batchNode.data.inputs.images.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'string_batch') {
|
||||
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
|
||||
return batchNode.data.inputs.strings.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'float_batch') {
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
return batchNode.data.inputs.floats.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
|
||||
return batchNode.data.inputs.integers.value?.length ?? 0;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
return batchSize;
|
||||
}, 0)
|
||||
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
|
||||
createMemoizedSelector(selectNodesSlice, ({ nodes }) => {
|
||||
return nodes
|
||||
.filter(isInvocationNode)
|
||||
.filter(isBatchNode)
|
||||
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
|
||||
.map(getBatchCollectionSize);
|
||||
});
|
||||
|
||||
export const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');
|
||||
export const selectGroup1BatchSizes = buildSelectGroupBatchSizes('Group 1');
|
||||
export const selectGroup2BatchSizes = buildSelectGroupBatchSizes('Group 2');
|
||||
export const selectGroup3BatchSizes = buildSelectGroupBatchSizes('Group 3');
|
||||
export const selectGroup4BatchSizes = buildSelectGroupBatchSizes('Group 4');
|
||||
export const selectGroup5BatchSizes = buildSelectGroupBatchSizes('Group 5');
|
||||
|
||||
export const selectWorkflowsBatchSize = createSelector(
|
||||
selectUngroupedBatchSizes,
|
||||
selectGroup1BatchSizes,
|
||||
selectGroup2BatchSizes,
|
||||
selectGroup3BatchSizes,
|
||||
selectGroup4BatchSizes,
|
||||
selectGroup5BatchSizes,
|
||||
(
|
||||
ungroupedBatchSizes,
|
||||
group1BatchSizes,
|
||||
group2BatchSizes,
|
||||
group3BatchSizes,
|
||||
group4BatchSizes,
|
||||
group5BatchSizes
|
||||
): number | 'INVALID' | 'NO_BATCHES' => {
|
||||
// All batch nodes _must_ have a populated collection
|
||||
|
||||
const allBatchSizes = [
|
||||
...ungroupedBatchSizes,
|
||||
...group1BatchSizes,
|
||||
...group2BatchSizes,
|
||||
...group3BatchSizes,
|
||||
...group4BatchSizes,
|
||||
...group5BatchSizes,
|
||||
];
|
||||
|
||||
// There are no batch nodes
|
||||
if (allBatchSizes.length === 0) {
|
||||
return 'NO_BATCHES';
|
||||
}
|
||||
|
||||
// All batch nodes must have a populated collection
|
||||
if (allBatchSizes.some((size) => size === 0)) {
|
||||
return 'INVALID';
|
||||
}
|
||||
|
||||
for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) {
|
||||
// Ignore groups with no batch nodes
|
||||
if (group.length === 0) {
|
||||
continue;
|
||||
}
|
||||
// Grouped batch nodes must have the same collection size
|
||||
if (group.some((size) => size !== group[0])) {
|
||||
return 'INVALID';
|
||||
}
|
||||
}
|
||||
|
||||
// Total batch size = product of all ungrouped batches and each grouped batch
|
||||
const totalBatchSize = [
|
||||
...ungroupedBatchSizes,
|
||||
// In case of no batch nodes in a group, fall back to 1 for the product calculation
|
||||
group1BatchSizes[0] ?? 1,
|
||||
group2BatchSizes[0] ?? 1,
|
||||
group3BatchSizes[0] ?? 1,
|
||||
group4BatchSizes[0] ?? 1,
|
||||
group5BatchSizes[0] ?? 1,
|
||||
].reduce((acc, size) => acc * size, 1);
|
||||
|
||||
return totalBatchSize;
|
||||
}
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user