feat(ui): support integer generators

This commit is contained in:
psychedelicious
2025-01-16 23:19:27 +11:00
parent f345fde512
commit d69e90ca5e
13 changed files with 469 additions and 35 deletions

View File

@@ -857,9 +857,9 @@
"defaultVAE": "Default VAE"
},
"nodes": {
"floatGeneratorStartEndStep": "Start, End, Step",
"floatGeneratorStartCountStep": "Start, Count Step",
"floatGeneratorRandom": "Random",
"startEndStep": "Start, End, Step",
"startCountStep": "Start, Count Step",
"random": "Random",
"noBatchGroup": "no group",
"addNode": "Add Node",
"addNodeToolTip": "Add Node (Shift+A, Space)",

View File

@@ -1,5 +1,6 @@
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
@@ -42,6 +43,8 @@ import {
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isIntegerGeneratorFieldInputInstance,
isIntegerGeneratorFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
isIPAdapterModelFieldInputTemplate,
isLoRAModelFieldInputInstance,
@@ -244,6 +247,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <FloatGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIntegerGeneratorFieldInputInstance(fieldInstance) && isIntegerGeneratorFieldInputTemplate(fieldTemplate)) {
return <IntegerGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (fieldTemplate) {
// Fallback for when there is no component for the type
return null;

View File

@@ -86,9 +86,9 @@ export const FloatGeneratorFieldInputComponent = memo(
<Flex flexDir="column" gap={2}>
<Flex gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value="float_generator_start_end_step">{t('nodes.floatGeneratorStartEndStep')}</option>
<option value="float_generator_start_count_step">{t('nodes.floatGeneratorStartCountStep')}</option>
<option value="float_generator_random">{t('nodes.floatGeneratorRandom')}</option>
<option value="float_generator_start_end_step">{t('nodes.startEndStep')}</option>
<option value="float_generator_start_count_step">{t('nodes.startCountStep')}</option>
<option value="float_generator_random">{t('nodes.random')}</option>
</Select>
<IconButton aria-label="Reset" icon={<PiPencilSimpleFill />} variant="ghost" />
</Flex>

View File

@@ -0,0 +1,126 @@
import { Flex, IconButton, Select, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { IntegerGeneratorRandomSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorRandomSettings';
import { IntegerGeneratorStartCountStepSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorStartCountStepSettings';
import { IntegerGeneratorStartEndStepSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorStartEndStepSettings';
import { fieldIntegerGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
import {
getIntegerGeneratorRandomDefaults,
getIntegerGeneratorStartCountStepDefaults,
getIntegerGeneratorStartEndStepDefaults,
type IntegerGeneratorFieldInputInstance,
type IntegerGeneratorFieldInputTemplate,
resolveIntegerGeneratorField,
} from 'features/nodes/types/field';
import { round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPencilSimpleFill } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const getDefaultValue = (generatorType: string) => {
if (generatorType === 'integer_generator_start_end_step') {
return getIntegerGeneratorStartEndStepDefaults();
}
if (generatorType === 'integer_generator_start_count_step') {
return getIntegerGeneratorStartCountStepDefaults();
}
if (generatorType === 'integer_generator_random') {
return getIntegerGeneratorRandomDefaults();
}
return null;
};
export const IntegerGeneratorFieldInputComponent = memo(
(props: FieldComponentProps<IntegerGeneratorFieldInputInstance, IntegerGeneratorFieldInputTemplate>) => {
const { nodeId, field } = props;
const { t } = useTranslation();
const dispatch = useAppDispatch();
const onChange = useCallback(
(value: IntegerGeneratorFieldInputInstance['value']) => {
dispatch(
fieldIntegerGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const onChangeGeneratorType = useCallback(
(e: ChangeEvent<HTMLSelectElement>) => {
const value = getDefaultValue(e.target.value);
if (!value) {
return;
}
dispatch(
fieldIntegerGeneratorValueChanged({
nodeId,
fieldName: field.name,
value,
})
);
},
[dispatch, field.name, nodeId]
);
const resolvedValues = useMemo(() => resolveIntegerGeneratorField(field), [field]);
const resolvedValuesAsString = useMemo(() => {
if (resolvedValues.length === 0) {
return '<empty>';
} else {
return resolvedValues.map((val) => round(val, 2)).join(', ');
}
}, [resolvedValues]);
return (
<Flex flexDir="column" gap={2}>
<Flex gap={2}>
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
<option value="integer_generator_start_end_step">{t('nodes.startEndStep')}</option>
<option value="integer_generator_start_count_step">{t('nodes.startCountStep')}</option>
<option value="integer_generator_random">{t('nodes.random')}</option>
</Select>
<IconButton aria-label="Reset" icon={<PiPencilSimpleFill />} variant="ghost" />
</Flex>
{field.value.type === 'integer_generator_start_end_step' && (
<IntegerGeneratorStartEndStepSettings state={field.value} onChange={onChange} />
)}
{field.value.type === 'integer_generator_start_count_step' && (
<IntegerGeneratorStartCountStepSettings state={field.value} onChange={onChange} />
)}
{field.value.type === 'integer_generator_random' && (
<IntegerGeneratorRandomSettings state={field.value} onChange={onChange} />
)}
{/* We don't show previews for random generators, bc they are non-deterministic */}
{field.value.type !== 'integer_generator_random' && (
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
<Flex w="full" h="auto">
<OverlayScrollbarsComponent
className="nodrag nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
{resolvedValuesAsString}
</Text>
</OverlayScrollbarsComponent>
</Flex>
</Flex>
)}
</Flex>
);
}
);
IntegerGeneratorFieldInputComponent.displayName = 'IntegerGeneratorFieldInputComponent';

View File

@@ -0,0 +1,55 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
import type { IntegerGeneratorRandom } from 'features/nodes/types/field';
import { getIntegerGeneratorRandomDefaults } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
type IntegerGeneratorRandomSettingsProps = {
state: IntegerGeneratorRandom;
onChange: (state: IntegerGeneratorRandom) => void;
};
export const IntegerGeneratorRandomSettings = memo(({ state, onChange }: IntegerGeneratorRandomSettingsProps) => {
const { t } = useTranslation();
const onChangeMin = useCallback(
(min: number) => {
onChange({ ...state, min });
},
[onChange, state]
);
const onChangeMax = useCallback(
(max: number) => {
onChange({ ...state, max });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
const onReset = useCallback(() => {
onChange(getIntegerGeneratorRandomDefaults());
}, [onChange]);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.min')}</FormLabel>
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.max')}</FormLabel>
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
</Flex>
);
});
IntegerGeneratorRandomSettings.displayName = 'IntegerGeneratorRandomSettings';

View File

@@ -0,0 +1,59 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
import {
getIntegerGeneratorStartCountStepDefaults,
type IntegerGeneratorStartCountStep,
} from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
type IntegerGeneratorStartCountStepSettingsProps = {
state: IntegerGeneratorStartCountStep;
onChange: (state: IntegerGeneratorStartCountStep) => void;
};
export const IntegerGeneratorStartCountStepSettings = memo(
({ state, onChange }: IntegerGeneratorStartCountStepSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
const onReset = useCallback(() => {
onChange(getIntegerGeneratorStartCountStepDefaults());
}, [onChange]);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} />
</FormControl>
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
</Flex>
);
}
);
IntegerGeneratorStartCountStepSettings.displayName = 'IntegerGeneratorStartCountStepSettings';

View File

@@ -0,0 +1,56 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
import { getIntegerGeneratorStartEndStepDefaults, type IntegerGeneratorStartEndStep } from 'features/nodes/types/field';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
type IntegerGeneratorStartEndStepSettingsProps = {
state: IntegerGeneratorStartEndStep;
onChange: (state: IntegerGeneratorStartEndStep) => void;
};
export const IntegerGeneratorStartEndStepSettings = memo(
({ state, onChange }: IntegerGeneratorStartEndStepSettingsProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeEnd = useCallback(
(end: number) => {
onChange({ ...state, end });
},
[onChange, state]
);
const onReset = useCallback(() => {
onChange(getIntegerGeneratorStartEndStepDefaults());
}, [onChange]);
return (
<Flex gap={2} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.end')}</FormLabel>
<CompositeNumberInput value={state.end} onChange={onChangeEnd} min={-Infinity} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} />
</FormControl>
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
</Flex>
);
}
);
IntegerGeneratorStartEndStepSettings.displayName = 'IntegerGeneratorStartEndStepSettings';

View File

@@ -22,6 +22,7 @@ import type {
ImageFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IntegerGeneratorFieldValue,
IPAdapterModelFieldValue,
LoRAModelFieldValue,
MainModelFieldValue,
@@ -54,6 +55,7 @@ import {
zImageFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIntegerGeneratorFieldValue,
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
zMainModelFieldValue,
@@ -399,6 +401,9 @@ export const nodesSlice = createSlice({
fieldFloatGeneratorValueChanged: (state, action: FieldValueAction<FloatGeneratorFieldValue>) => {
fieldValueReducer(state, action, zFloatGeneratorFieldValue);
},
fieldIntegerGeneratorValueChanged: (state, action: FieldValueAction<IntegerGeneratorFieldValue>) => {
fieldValueReducer(state, action, zIntegerGeneratorFieldValue);
},
notesNodeValueChanged: (state, action: PayloadAction<{ nodeId: string; value: string }>) => {
const { nodeId, value } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
@@ -464,6 +469,7 @@ export const {
fieldControlLoRAModelValueChanged,
fieldFluxVAEModelValueChanged,
fieldFloatGeneratorValueChanged,
fieldIntegerGeneratorValueChanged,
nodeEditorReset,
nodeIsIntermediateChanged,
nodeIsOpenChanged,

View File

@@ -216,6 +216,10 @@ const zFloatGeneratorFieldType = zFieldTypeBase.extend({
name: z.literal('FloatGeneratorField'),
originalType: zStatelessFieldType.optional(),
});
const zIntegerGeneratorFieldType = zFieldTypeBase.extend({
name: z.literal('IntegerGeneratorField'),
originalType: zStatelessFieldType.optional(),
});
const zStatefulFieldType = z.union([
zIntegerFieldType,
zFloatFieldType,
@@ -245,6 +249,7 @@ const zStatefulFieldType = z.union([
zColorFieldType,
zSchedulerFieldType,
zFloatGeneratorFieldType,
zIntegerGeneratorFieldType,
]);
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
@@ -1093,6 +1098,93 @@ export const resolveFloatGeneratorField = ({ value }: FloatGeneratorFieldInputIn
};
// #endregion
// #region IntegerGeneratorField
const zIntegerGeneratorStartEndStep = z.object({
type: z.literal('integer_generator_start_end_step').default('integer_generator_start_end_step'),
start: z.number().int().default(0),
end: z.number().int().default(10),
step: z.number().int().default(1),
values: z.array(z.number().int()).nullish(),
});
export type IntegerGeneratorStartEndStep = z.infer<typeof zIntegerGeneratorStartEndStep>;
export const getIntegerGeneratorStartEndStepDefaults = () => zIntegerGeneratorStartEndStep.parse({});
export const getIntegerGeneratorStartEndStepValues = ({ start, end, step }: IntegerGeneratorStartEndStep) => {
if (step === 0) {
return [];
}
const values = Array.from({ length: Math.floor((end - start) / step) + 1 }, (_, i) => start + i * step);
return values;
};
const zIntegerGeneratorStartCountStep = z.object({
type: z.literal('integer_generator_start_count_step').default('integer_generator_start_count_step'),
start: z.number().int().default(0),
step: z.number().int().default(1),
count: z.number().int().default(10),
values: z.array(z.number().int()).nullish(),
});
export type IntegerGeneratorStartCountStep = z.infer<typeof zIntegerGeneratorStartCountStep>;
export const getIntegerGeneratorStartCountStepDefaults = () => zIntegerGeneratorStartCountStep.parse({});
export const getIntegerGeneratorStartCountStepValues = ({ start, count, step }: IntegerGeneratorStartCountStep) => {
if (step === 0) {
return [];
}
const values = Array.from({ length: count }, (_, i) => start + i * step);
return values;
};
const zIntegerGeneratorRandom = z.object({
type: z.literal('integer_generator_random').default('integer_generator_random'),
min: z.number().int().default(0),
max: z.number().int().default(10),
count: z.number().int().default(10),
values: z.array(z.number().int()).nullish(),
});
export type IntegerGeneratorRandom = z.infer<typeof zIntegerGeneratorRandom>;
export const getIntegerGeneratorRandomDefaults = () => zIntegerGeneratorRandom.parse({});
export const getIntegerGeneratorRandomValues = ({ min, max, count }: IntegerGeneratorRandom) => {
const values = Array.from({ length: count }, () => Math.floor(Math.random() * (max - min + 1) + min));
return values;
};
export const zIntegerGeneratorFieldValue = z.union([
zIntegerGeneratorStartEndStep,
zIntegerGeneratorStartCountStep,
zIntegerGeneratorRandom,
]);
const zIntegerGeneratorFieldInputInstance = zFieldInputInstanceBase.extend({
value: zIntegerGeneratorFieldValue,
});
const zIntegerGeneratorFieldInputTemplate = zFieldInputTemplateBase.extend({
type: zIntegerGeneratorFieldType,
originalType: zFieldType.optional(),
default: zIntegerGeneratorFieldValue,
});
const zIntegerGeneratorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIntegerGeneratorFieldType,
});
export type IntegerGeneratorFieldValue = z.infer<typeof zIntegerGeneratorFieldValue>;
export type IntegerGeneratorFieldInputInstance = z.infer<typeof zIntegerGeneratorFieldInputInstance>;
export type IntegerGeneratorFieldInputTemplate = z.infer<typeof zIntegerGeneratorFieldInputTemplate>;
export const isIntegerGeneratorFieldInputInstance = buildTypeGuard(zIntegerGeneratorFieldInputInstance);
export const isIntegerGeneratorFieldInputTemplate = buildTypeGuard(zIntegerGeneratorFieldInputTemplate);
export const resolveIntegerGeneratorField = ({ value }: IntegerGeneratorFieldInputInstance) => {
if (value.values) {
return value.values;
}
if (value.type === 'integer_generator_start_end_step') {
return getIntegerGeneratorStartEndStepValues(value);
}
if (value.type === 'integer_generator_start_count_step') {
return getIntegerGeneratorStartCountStepValues(value);
}
if (value.type === 'integer_generator_random') {
return getIntegerGeneratorRandomValues(value);
}
assert(false, 'Invalid integer generator type');
};
// #endregion
// #region StatelessField
/**
* StatelessField is a catchall for stateless fields with no UI input components. They do not
@@ -1172,6 +1264,7 @@ export const zStatefulFieldValue = z.union([
zColorFieldValue,
zSchedulerFieldValue,
zFloatGeneratorFieldValue,
zIntegerGeneratorFieldValue,
]);
export type StatefulFieldValue = z.infer<typeof zStatefulFieldValue>;
@@ -1210,6 +1303,7 @@ const zStatefulFieldInputInstance = z.union([
zColorFieldInputInstance,
zSchedulerFieldInputInstance,
zFloatGeneratorFieldInputInstance,
zIntegerGeneratorFieldInputInstance,
]);
export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]);
@@ -1252,6 +1346,7 @@ const zStatefulFieldInputTemplate = z.union([
zSchedulerFieldInputTemplate,
zStatelessFieldInputTemplate,
zFloatGeneratorFieldInputTemplate,
zIntegerGeneratorFieldInputTemplate,
]);
export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]);
@@ -1287,6 +1382,7 @@ const zStatefulFieldOutputTemplate = z.union([
zColorFieldOutputTemplate,
zSchedulerFieldOutputTemplate,
zFloatGeneratorFieldOutputTemplate,
zIntegerGeneratorFieldOutputTemplate,
]);
export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]);

View File

@@ -107,6 +107,7 @@ export const isBatchNode = (node: InvocationNode) => {
export const isGeneratorNode = (node: InvocationNode) => {
switch (node.data.type) {
case 'float_generator':
case 'integer_generator':
return true;
default:
return false;

View File

@@ -20,6 +20,7 @@ import type {
ImageFieldInputTemplate,
IntegerFieldCollectionInputTemplate,
IntegerFieldInputTemplate,
IntegerGeneratorFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
LoRAModelFieldInputTemplate,
MainModelFieldInputTemplate,
@@ -639,14 +640,14 @@ const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<SchedulerField
};
const buildFloatGeneratorFieldInputTemplate: FieldInputTemplateBuilder<FloatGeneratorFieldInputTemplate> = ({
schemaObject,
// schemaObject,
baseField,
fieldType,
}) => {
const template: FloatGeneratorFieldInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? {
default: {
type: 'float_generator_start_end_step',
start: 0,
end: 1,
@@ -658,6 +659,26 @@ const buildFloatGeneratorFieldInputTemplate: FieldInputTemplateBuilder<FloatGene
return template;
};
const buildIntegerGeneratorFieldInputTemplate: FieldInputTemplateBuilder<IntegerGeneratorFieldInputTemplate> = ({
// schemaObject,
baseField,
fieldType,
}) => {
const template: IntegerGeneratorFieldInputTemplate = {
...baseField,
type: fieldType,
default: {
type: 'integer_generator_start_end_step',
start: 0,
end: 10,
step: 1,
values: undefined,
},
};
return template;
};
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
BoardField: buildBoardFieldInputTemplate,
BooleanField: buildBooleanFieldInputTemplate,
@@ -687,6 +708,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
} as const;
export const buildFieldInputTemplate = (

View File

@@ -130,6 +130,8 @@ export const parseSchema = (
if (type === 'float_batch' && propertyName === 'floats') {
fieldType.batch = true;
} else if (type === 'integer_batch' && propertyName === 'integers') {
fieldType.batch = true;
}
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
@@ -195,6 +197,8 @@ export const parseSchema = (
if (type === 'float_generator' && propertyName === 'floats') {
fieldType.batch = true;
} else if (type === 'integer_generator' && propertyName === 'integers') {
fieldType.batch = true;
}
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);

View File

@@ -27,9 +27,11 @@ import {
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerGeneratorFieldInputInstance,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
resolveFloatGeneratorField,
resolveIntegerGeneratorField,
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
@@ -81,19 +83,36 @@ export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNo
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
const ownValue = batchNode.data.inputs.floats.value;
const edgeToFloats = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'floats');
if (!edgeToFloats) {
return ownValue ?? [];
}
const generator = nodes.find((node) => node.id === edgeToFloats.source);
assert(generator, 'Missing edge from float generator to float batch');
assert(isFloatGeneratorFieldInputInstance(generator.data.inputs['generator']), 'Invalid float generator');
const generatorValue = resolveFloatGeneratorField(generator.data.inputs['generator']);
const generatorNode = nodes.find((node) => node.id === edgeToFloats.source);
assert(generatorNode, 'Missing edge from float generator to float batch');
const generatorField = generatorNode.data.inputs['generator'];
assert(isFloatGeneratorFieldInputInstance(generatorField), 'Invalid float generator');
const generatorValue = resolveFloatGeneratorField(generatorField);
return generatorValue;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
const ownValue = batchNode.data.inputs.integers.value;
// no generators for integers yet
return ownValue ?? [];
const incomers = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'integers');
if (!incomers) {
return ownValue ?? [];
}
const generatorNode = nodes.find((node) => node.id === incomers.source);
assert(generatorNode, 'Missing edge from integer generator to integer batch');
const generatorField = generatorNode.data.inputs['generator'];
assert(isIntegerGeneratorFieldInputInstance(generatorField), 'Invalid integer generator field');
const generatorValue = resolveIntegerGeneratorField(generatorField);
return generatorValue;
}
assert(false, 'Invalid batch node type');
};
@@ -569,30 +588,13 @@ export const selectPromptsCount = createSelector(
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
);
const getBatchCollectionSize = (batchNode: InvocationNode) => {
if (batchNode.data.type === 'image_batch') {
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
return batchNode.data.inputs.images.value?.length ?? 0;
} else if (batchNode.data.type === 'string_batch') {
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
return batchNode.data.inputs.strings.value?.length ?? 0;
} else if (batchNode.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
return batchNode.data.inputs.floats.value?.length ?? 0;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
return batchNode.data.inputs.integers.value?.length ?? 0;
}
return 0;
};
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
createMemoizedSelector(selectNodesSlice, ({ nodes }) => {
return nodes
.filter(isInvocationNode)
createMemoizedSelector(selectNodesSlice, ({ nodes, edges }) => {
const invocationNodes = nodes.filter(isInvocationNode);
return invocationNodes
.filter(isBatchNode)
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
.map(getBatchCollectionSize);
.map((batchNodes) => resolveBatchValue(batchNodes, invocationNodes, edges).length);
});
const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');