mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): support integer generators
This commit is contained in:
@@ -27,9 +27,11 @@ import {
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerGeneratorFieldInputInstance,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
resolveFloatGeneratorField,
|
||||
resolveIntegerGeneratorField,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
validateImageFieldCollectionValue,
|
||||
@@ -81,19 +83,36 @@ export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNo
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
const ownValue = batchNode.data.inputs.floats.value;
|
||||
const edgeToFloats = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'floats');
|
||||
|
||||
if (!edgeToFloats) {
|
||||
return ownValue ?? [];
|
||||
}
|
||||
const generator = nodes.find((node) => node.id === edgeToFloats.source);
|
||||
assert(generator, 'Missing edge from float generator to float batch');
|
||||
assert(isFloatGeneratorFieldInputInstance(generator.data.inputs['generator']), 'Invalid float generator');
|
||||
const generatorValue = resolveFloatGeneratorField(generator.data.inputs['generator']);
|
||||
|
||||
const generatorNode = nodes.find((node) => node.id === edgeToFloats.source);
|
||||
assert(generatorNode, 'Missing edge from float generator to float batch');
|
||||
|
||||
const generatorField = generatorNode.data.inputs['generator'];
|
||||
assert(isFloatGeneratorFieldInputInstance(generatorField), 'Invalid float generator');
|
||||
|
||||
const generatorValue = resolveFloatGeneratorField(generatorField);
|
||||
return generatorValue;
|
||||
} else if (batchNode.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
|
||||
const ownValue = batchNode.data.inputs.integers.value;
|
||||
// no generators for integers yet
|
||||
return ownValue ?? [];
|
||||
const incomers = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'integers');
|
||||
|
||||
if (!incomers) {
|
||||
return ownValue ?? [];
|
||||
}
|
||||
|
||||
const generatorNode = nodes.find((node) => node.id === incomers.source);
|
||||
assert(generatorNode, 'Missing edge from integer generator to integer batch');
|
||||
|
||||
const generatorField = generatorNode.data.inputs['generator'];
|
||||
assert(isIntegerGeneratorFieldInputInstance(generatorField), 'Invalid integer generator field');
|
||||
|
||||
const generatorValue = resolveIntegerGeneratorField(generatorField);
|
||||
return generatorValue;
|
||||
}
|
||||
assert(false, 'Invalid batch node type');
|
||||
};
|
||||
@@ -569,30 +588,13 @@ export const selectPromptsCount = createSelector(
|
||||
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
|
||||
);
|
||||
|
||||
const getBatchCollectionSize = (batchNode: InvocationNode) => {
|
||||
if (batchNode.data.type === 'image_batch') {
|
||||
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
|
||||
return batchNode.data.inputs.images.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'string_batch') {
|
||||
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
|
||||
return batchNode.data.inputs.strings.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'float_batch') {
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
return batchNode.data.inputs.floats.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
|
||||
return batchNode.data.inputs.integers.value?.length ?? 0;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
|
||||
createMemoizedSelector(selectNodesSlice, ({ nodes }) => {
|
||||
return nodes
|
||||
.filter(isInvocationNode)
|
||||
createMemoizedSelector(selectNodesSlice, ({ nodes, edges }) => {
|
||||
const invocationNodes = nodes.filter(isInvocationNode);
|
||||
return invocationNodes
|
||||
.filter(isBatchNode)
|
||||
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
|
||||
.map(getBatchCollectionSize);
|
||||
.map((batchNodes) => resolveBatchValue(batchNodes, invocationNodes, edges).length);
|
||||
});
|
||||
|
||||
const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');
|
||||
|
||||
Reference in New Issue
Block a user