feat(ui): add zipped batch collection size validation

This commit is contained in:
psychedelicious
2025-01-14 11:55:47 +11:00
parent 0abc0be931
commit 803ec8e904
2 changed files with 41 additions and 4 deletions

View File

@@ -38,8 +38,9 @@ import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import i18n from 'i18next';
import { forEach, negate, upperFirst } from 'lodash-es';
import { forEach, groupBy, negate, upperFirst } from 'lodash-es';
import { getConnectedEdges } from 'reactflow';
import { assert } from 'tsafe';
/**
* This file contains selectors and utilities for determining the app is ready to enqueue generations. The handling
@@ -75,13 +76,48 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
}
if (workflowSettings.shouldValidateGraph) {
const nodesToCheck = nodes.nodes.filter(isInvocationNode).filter(negate(isBatchNode));
const invocationNodes = nodes.nodes.filter(isInvocationNode);
const batchNodes = invocationNodes.filter(isBatchNode);
const nonBatchNodes = invocationNodes.filter(negate(isBatchNode));
if (!nodesToCheck.length) {
if (!nonBatchNodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
}
nodesToCheck.forEach((node) => {
if (batchNodes.length > 1) {
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['link_id']?.value);
for (const [linkId, batchNodes] of Object.entries(groupedBatchNodes)) {
if (!linkId) {
// No link ID implies ungrouped batch nodes, their collection sizes can be different
continue;
}
// Otherwise, all batch nodes with the same link ID must have the same collection size
const collectionSizes: number[] = [];
for (const node of batchNodes) {
if (node.data.type === 'image_batch') {
assert(isImageFieldCollectionInputInstance(node.data.inputs.images));
collectionSizes.push(node.data.inputs.images.value?.length ?? 0);
} else if (node.data.type === 'string_batch') {
assert(isStringFieldCollectionInputInstance(node.data.inputs.strings));
collectionSizes.push(node.data.inputs.strings.value?.length ?? 0);
} else if (node.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(node.data.inputs.floats));
collectionSizes.push(node.data.inputs.floats.value?.length ?? 0);
} else if (node.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(node.data.inputs.integers));
collectionSizes.push(node.data.inputs.integers.value?.length ?? 0);
}
}
if (collectionSizes.some((count) => count !== collectionSizes[0])) {
reasons.push({ content: i18n.t('parameters.invoke.batchNodeCollectionSizeMismatch', { linkId }) });
}
}
}
nonBatchNodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}