feat(ui): support integer generators

This commit is contained in:
psychedelicious
2025-01-16 23:19:27 +11:00
parent f345fde512
commit d69e90ca5e
13 changed files with 469 additions and 35 deletions

View File

@@ -27,9 +27,11 @@ import {
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerGeneratorFieldInputInstance,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
resolveFloatGeneratorField,
resolveIntegerGeneratorField,
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
@@ -81,19 +83,36 @@ export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNo
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
const ownValue = batchNode.data.inputs.floats.value;
const edgeToFloats = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'floats');
if (!edgeToFloats) {
return ownValue ?? [];
}
const generator = nodes.find((node) => node.id === edgeToFloats.source);
assert(generator, 'Missing edge from float generator to float batch');
assert(isFloatGeneratorFieldInputInstance(generator.data.inputs['generator']), 'Invalid float generator');
const generatorValue = resolveFloatGeneratorField(generator.data.inputs['generator']);
const generatorNode = nodes.find((node) => node.id === edgeToFloats.source);
assert(generatorNode, 'Missing edge from float generator to float batch');
const generatorField = generatorNode.data.inputs['generator'];
assert(isFloatGeneratorFieldInputInstance(generatorField), 'Invalid float generator');
const generatorValue = resolveFloatGeneratorField(generatorField);
return generatorValue;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
const ownValue = batchNode.data.inputs.integers.value;
// no generators for integers yet
return ownValue ?? [];
const incomers = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'integers');
if (!incomers) {
return ownValue ?? [];
}
const generatorNode = nodes.find((node) => node.id === incomers.source);
assert(generatorNode, 'Missing edge from integer generator to integer batch');
const generatorField = generatorNode.data.inputs['generator'];
assert(isIntegerGeneratorFieldInputInstance(generatorField), 'Invalid integer generator field');
const generatorValue = resolveIntegerGeneratorField(generatorField);
return generatorValue;
}
assert(false, 'Invalid batch node type');
};
@@ -569,30 +588,13 @@ export const selectPromptsCount = createSelector(
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
);
const getBatchCollectionSize = (batchNode: InvocationNode) => {
if (batchNode.data.type === 'image_batch') {
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
return batchNode.data.inputs.images.value?.length ?? 0;
} else if (batchNode.data.type === 'string_batch') {
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
return batchNode.data.inputs.strings.value?.length ?? 0;
} else if (batchNode.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
return batchNode.data.inputs.floats.value?.length ?? 0;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
return batchNode.data.inputs.integers.value?.length ?? 0;
}
return 0;
};
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
createMemoizedSelector(selectNodesSlice, ({ nodes }) => {
return nodes
.filter(isInvocationNode)
createMemoizedSelector(selectNodesSlice, ({ nodes, edges }) => {
const invocationNodes = nodes.filter(isInvocationNode);
return invocationNodes
.filter(isBatchNode)
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
.map(getBatchCollectionSize);
.map((batchNodes) => resolveBatchValue(batchNodes, invocationNodes, edges).length);
});
const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');