mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): support integer generators
This commit is contained in:
@@ -857,9 +857,9 @@
|
||||
"defaultVAE": "Default VAE"
|
||||
},
|
||||
"nodes": {
|
||||
"floatGeneratorStartEndStep": "Start, End, Step",
|
||||
"floatGeneratorStartCountStep": "Start, Count Step",
|
||||
"floatGeneratorRandom": "Random",
|
||||
"startEndStep": "Start, End, Step",
|
||||
"startCountStep": "Start, Count Step",
|
||||
"random": "Random",
|
||||
"noBatchGroup": "no group",
|
||||
"addNode": "Add Node",
|
||||
"addNodeToolTip": "Add Node (Shift+A, Space)",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { FloatGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatGeneratorFieldComponent';
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import { IntegerGeneratorFieldInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorFieldComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
|
||||
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
|
||||
@@ -42,6 +43,8 @@ import {
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
isIntegerFieldInputTemplate,
|
||||
isIntegerGeneratorFieldInputInstance,
|
||||
isIntegerGeneratorFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
isIPAdapterModelFieldInputTemplate,
|
||||
isLoRAModelFieldInputInstance,
|
||||
@@ -244,6 +247,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <FloatGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isIntegerGeneratorFieldInputInstance(fieldInstance) && isIntegerGeneratorFieldInputTemplate(fieldTemplate)) {
|
||||
return <IntegerGeneratorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (fieldTemplate) {
|
||||
// Fallback for when there is no component for the type
|
||||
return null;
|
||||
|
||||
@@ -86,9 +86,9 @@ export const FloatGeneratorFieldInputComponent = memo(
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Flex gap={2}>
|
||||
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
|
||||
<option value="float_generator_start_end_step">{t('nodes.floatGeneratorStartEndStep')}</option>
|
||||
<option value="float_generator_start_count_step">{t('nodes.floatGeneratorStartCountStep')}</option>
|
||||
<option value="float_generator_random">{t('nodes.floatGeneratorRandom')}</option>
|
||||
<option value="float_generator_start_end_step">{t('nodes.startEndStep')}</option>
|
||||
<option value="float_generator_start_count_step">{t('nodes.startCountStep')}</option>
|
||||
<option value="float_generator_random">{t('nodes.random')}</option>
|
||||
</Select>
|
||||
<IconButton aria-label="Reset" icon={<PiPencilSimpleFill />} variant="ghost" />
|
||||
</Flex>
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
import { Flex, IconButton, Select, Text } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { IntegerGeneratorRandomSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorRandomSettings';
|
||||
import { IntegerGeneratorStartCountStepSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorStartCountStepSettings';
|
||||
import { IntegerGeneratorStartEndStepSettings } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/IntegerGeneratorStartEndStepSettings';
|
||||
import { fieldIntegerGeneratorValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
getIntegerGeneratorRandomDefaults,
|
||||
getIntegerGeneratorStartCountStepDefaults,
|
||||
getIntegerGeneratorStartEndStepDefaults,
|
||||
type IntegerGeneratorFieldInputInstance,
|
||||
type IntegerGeneratorFieldInputTemplate,
|
||||
resolveIntegerGeneratorField,
|
||||
} from 'features/nodes/types/field';
|
||||
import { round } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPencilSimpleFill } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
const getDefaultValue = (generatorType: string) => {
|
||||
if (generatorType === 'integer_generator_start_end_step') {
|
||||
return getIntegerGeneratorStartEndStepDefaults();
|
||||
}
|
||||
if (generatorType === 'integer_generator_start_count_step') {
|
||||
return getIntegerGeneratorStartCountStepDefaults();
|
||||
}
|
||||
if (generatorType === 'integer_generator_random') {
|
||||
return getIntegerGeneratorRandomDefaults();
|
||||
}
|
||||
return null;
|
||||
};
|
||||
|
||||
export const IntegerGeneratorFieldInputComponent = memo(
|
||||
(props: FieldComponentProps<IntegerGeneratorFieldInputInstance, IntegerGeneratorFieldInputTemplate>) => {
|
||||
const { nodeId, field } = props;
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const onChange = useCallback(
|
||||
(value: IntegerGeneratorFieldInputInstance['value']) => {
|
||||
dispatch(
|
||||
fieldIntegerGeneratorValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const onChangeGeneratorType = useCallback(
|
||||
(e: ChangeEvent<HTMLSelectElement>) => {
|
||||
const value = getDefaultValue(e.target.value);
|
||||
if (!value) {
|
||||
return;
|
||||
}
|
||||
dispatch(
|
||||
fieldIntegerGeneratorValueChanged({
|
||||
nodeId,
|
||||
fieldName: field.name,
|
||||
value,
|
||||
})
|
||||
);
|
||||
},
|
||||
[dispatch, field.name, nodeId]
|
||||
);
|
||||
|
||||
const resolvedValues = useMemo(() => resolveIntegerGeneratorField(field), [field]);
|
||||
const resolvedValuesAsString = useMemo(() => {
|
||||
if (resolvedValues.length === 0) {
|
||||
return '<empty>';
|
||||
} else {
|
||||
return resolvedValues.map((val) => round(val, 2)).join(', ');
|
||||
}
|
||||
}, [resolvedValues]);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" gap={2}>
|
||||
<Flex gap={2}>
|
||||
<Select className="nowheel nodrag" onChange={onChangeGeneratorType} value={field.value.type} size="sm">
|
||||
<option value="integer_generator_start_end_step">{t('nodes.startEndStep')}</option>
|
||||
<option value="integer_generator_start_count_step">{t('nodes.startCountStep')}</option>
|
||||
<option value="integer_generator_random">{t('nodes.random')}</option>
|
||||
</Select>
|
||||
<IconButton aria-label="Reset" icon={<PiPencilSimpleFill />} variant="ghost" />
|
||||
</Flex>
|
||||
{field.value.type === 'integer_generator_start_end_step' && (
|
||||
<IntegerGeneratorStartEndStepSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === 'integer_generator_start_count_step' && (
|
||||
<IntegerGeneratorStartCountStepSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{field.value.type === 'integer_generator_random' && (
|
||||
<IntegerGeneratorRandomSettings state={field.value} onChange={onChange} />
|
||||
)}
|
||||
{/* We don't show previews for random generators, bc they are non-deterministic */}
|
||||
{field.value.type !== 'integer_generator_random' && (
|
||||
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base" maxH={128}>
|
||||
<Flex w="full" h="auto">
|
||||
<OverlayScrollbarsComponent
|
||||
className="nodrag nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Text className="nodrag nowheel" fontFamily="monospace" userSelect="text" cursor="text">
|
||||
{resolvedValuesAsString}
|
||||
</Text>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Flex>
|
||||
</Flex>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
IntegerGeneratorFieldInputComponent.displayName = 'IntegerGeneratorFieldInputComponent';
|
||||
@@ -0,0 +1,55 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
|
||||
import type { IntegerGeneratorRandom } from 'features/nodes/types/field';
|
||||
import { getIntegerGeneratorRandomDefaults } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
type IntegerGeneratorRandomSettingsProps = {
|
||||
state: IntegerGeneratorRandom;
|
||||
onChange: (state: IntegerGeneratorRandom) => void;
|
||||
};
|
||||
export const IntegerGeneratorRandomSettings = memo(({ state, onChange }: IntegerGeneratorRandomSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeMin = useCallback(
|
||||
(min: number) => {
|
||||
onChange({ ...state, min });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeMax = useCallback(
|
||||
(max: number) => {
|
||||
onChange({ ...state, max });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onReset = useCallback(() => {
|
||||
onChange(getIntegerGeneratorRandomDefaults());
|
||||
}, [onChange]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.min')}</FormLabel>
|
||||
<CompositeNumberInput value={state.min} onChange={onChangeMin} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.max')}</FormLabel>
|
||||
<CompositeNumberInput value={state.max} onChange={onChangeMax} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
IntegerGeneratorRandomSettings.displayName = 'IntegerGeneratorRandomSettings';
|
||||
@@ -0,0 +1,59 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
getIntegerGeneratorStartCountStepDefaults,
|
||||
type IntegerGeneratorStartCountStep,
|
||||
} from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
type IntegerGeneratorStartCountStepSettingsProps = {
|
||||
state: IntegerGeneratorStartCountStep;
|
||||
onChange: (state: IntegerGeneratorStartCountStep) => void;
|
||||
};
|
||||
export const IntegerGeneratorStartCountStepSettings = memo(
|
||||
({ state, onChange }: IntegerGeneratorStartCountStepSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeStart = useCallback(
|
||||
(start: number) => {
|
||||
onChange({ ...state, start });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeStep = useCallback(
|
||||
(step: number) => {
|
||||
onChange({ ...state, step });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onReset = useCallback(() => {
|
||||
onChange(getIntegerGeneratorStartCountStepDefaults());
|
||||
}, [onChange]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.step')}</FormLabel>
|
||||
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
IntegerGeneratorStartCountStepSettings.displayName = 'IntegerGeneratorStartCountStepSettings';
|
||||
@@ -0,0 +1,56 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
|
||||
import { getIntegerGeneratorStartEndStepDefaults, type IntegerGeneratorStartEndStep } from 'features/nodes/types/field';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
type IntegerGeneratorStartEndStepSettingsProps = {
|
||||
state: IntegerGeneratorStartEndStep;
|
||||
onChange: (state: IntegerGeneratorStartEndStep) => void;
|
||||
};
|
||||
export const IntegerGeneratorStartEndStepSettings = memo(
|
||||
({ state, onChange }: IntegerGeneratorStartEndStepSettingsProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeStart = useCallback(
|
||||
(start: number) => {
|
||||
onChange({ ...state, start });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeStep = useCallback(
|
||||
(step: number) => {
|
||||
onChange({ ...state, step });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeEnd = useCallback(
|
||||
(end: number) => {
|
||||
onChange({ ...state, end });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onReset = useCallback(() => {
|
||||
onChange(getIntegerGeneratorStartEndStepDefaults());
|
||||
}, [onChange]);
|
||||
|
||||
return (
|
||||
<Flex gap={2} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.end')}</FormLabel>
|
||||
<CompositeNumberInput value={state.end} onChange={onChangeEnd} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.step')}</FormLabel>
|
||||
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} />
|
||||
</FormControl>
|
||||
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
IntegerGeneratorStartEndStepSettings.displayName = 'IntegerGeneratorStartEndStepSettings';
|
||||
@@ -22,6 +22,7 @@ import type {
|
||||
ImageFieldValue,
|
||||
IntegerFieldCollectionValue,
|
||||
IntegerFieldValue,
|
||||
IntegerGeneratorFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
MainModelFieldValue,
|
||||
@@ -54,6 +55,7 @@ import {
|
||||
zImageFieldValue,
|
||||
zIntegerFieldCollectionValue,
|
||||
zIntegerFieldValue,
|
||||
zIntegerGeneratorFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
zMainModelFieldValue,
|
||||
@@ -399,6 +401,9 @@ export const nodesSlice = createSlice({
|
||||
fieldFloatGeneratorValueChanged: (state, action: FieldValueAction<FloatGeneratorFieldValue>) => {
|
||||
fieldValueReducer(state, action, zFloatGeneratorFieldValue);
|
||||
},
|
||||
fieldIntegerGeneratorValueChanged: (state, action: FieldValueAction<IntegerGeneratorFieldValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerGeneratorFieldValue);
|
||||
},
|
||||
notesNodeValueChanged: (state, action: PayloadAction<{ nodeId: string; value: string }>) => {
|
||||
const { nodeId, value } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
@@ -464,6 +469,7 @@ export const {
|
||||
fieldControlLoRAModelValueChanged,
|
||||
fieldFluxVAEModelValueChanged,
|
||||
fieldFloatGeneratorValueChanged,
|
||||
fieldIntegerGeneratorValueChanged,
|
||||
nodeEditorReset,
|
||||
nodeIsIntermediateChanged,
|
||||
nodeIsOpenChanged,
|
||||
|
||||
@@ -216,6 +216,10 @@ const zFloatGeneratorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatGeneratorField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zIntegerGeneratorFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IntegerGeneratorField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStatefulFieldType = z.union([
|
||||
zIntegerFieldType,
|
||||
zFloatFieldType,
|
||||
@@ -245,6 +249,7 @@ const zStatefulFieldType = z.union([
|
||||
zColorFieldType,
|
||||
zSchedulerFieldType,
|
||||
zFloatGeneratorFieldType,
|
||||
zIntegerGeneratorFieldType,
|
||||
]);
|
||||
export type StatefulFieldType = z.infer<typeof zStatefulFieldType>;
|
||||
const statefulFieldTypeNames = zStatefulFieldType.options.map((o) => o.shape.name.value);
|
||||
@@ -1093,6 +1098,93 @@ export const resolveFloatGeneratorField = ({ value }: FloatGeneratorFieldInputIn
|
||||
};
|
||||
// #endregion
|
||||
|
||||
// #region IntegerGeneratorField
|
||||
const zIntegerGeneratorStartEndStep = z.object({
|
||||
type: z.literal('integer_generator_start_end_step').default('integer_generator_start_end_step'),
|
||||
start: z.number().int().default(0),
|
||||
end: z.number().int().default(10),
|
||||
step: z.number().int().default(1),
|
||||
values: z.array(z.number().int()).nullish(),
|
||||
});
|
||||
export type IntegerGeneratorStartEndStep = z.infer<typeof zIntegerGeneratorStartEndStep>;
|
||||
export const getIntegerGeneratorStartEndStepDefaults = () => zIntegerGeneratorStartEndStep.parse({});
|
||||
export const getIntegerGeneratorStartEndStepValues = ({ start, end, step }: IntegerGeneratorStartEndStep) => {
|
||||
if (step === 0) {
|
||||
return [];
|
||||
}
|
||||
const values = Array.from({ length: Math.floor((end - start) / step) + 1 }, (_, i) => start + i * step);
|
||||
return values;
|
||||
};
|
||||
|
||||
const zIntegerGeneratorStartCountStep = z.object({
|
||||
type: z.literal('integer_generator_start_count_step').default('integer_generator_start_count_step'),
|
||||
start: z.number().int().default(0),
|
||||
step: z.number().int().default(1),
|
||||
count: z.number().int().default(10),
|
||||
values: z.array(z.number().int()).nullish(),
|
||||
});
|
||||
export type IntegerGeneratorStartCountStep = z.infer<typeof zIntegerGeneratorStartCountStep>;
|
||||
export const getIntegerGeneratorStartCountStepDefaults = () => zIntegerGeneratorStartCountStep.parse({});
|
||||
export const getIntegerGeneratorStartCountStepValues = ({ start, count, step }: IntegerGeneratorStartCountStep) => {
|
||||
if (step === 0) {
|
||||
return [];
|
||||
}
|
||||
const values = Array.from({ length: count }, (_, i) => start + i * step);
|
||||
return values;
|
||||
};
|
||||
|
||||
const zIntegerGeneratorRandom = z.object({
|
||||
type: z.literal('integer_generator_random').default('integer_generator_random'),
|
||||
min: z.number().int().default(0),
|
||||
max: z.number().int().default(10),
|
||||
count: z.number().int().default(10),
|
||||
values: z.array(z.number().int()).nullish(),
|
||||
});
|
||||
export type IntegerGeneratorRandom = z.infer<typeof zIntegerGeneratorRandom>;
|
||||
export const getIntegerGeneratorRandomDefaults = () => zIntegerGeneratorRandom.parse({});
|
||||
export const getIntegerGeneratorRandomValues = ({ min, max, count }: IntegerGeneratorRandom) => {
|
||||
const values = Array.from({ length: count }, () => Math.floor(Math.random() * (max - min + 1) + min));
|
||||
return values;
|
||||
};
|
||||
|
||||
export const zIntegerGeneratorFieldValue = z.union([
|
||||
zIntegerGeneratorStartEndStep,
|
||||
zIntegerGeneratorStartCountStep,
|
||||
zIntegerGeneratorRandom,
|
||||
]);
|
||||
const zIntegerGeneratorFieldInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIntegerGeneratorFieldValue,
|
||||
});
|
||||
const zIntegerGeneratorFieldInputTemplate = zFieldInputTemplateBase.extend({
|
||||
type: zIntegerGeneratorFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zIntegerGeneratorFieldValue,
|
||||
});
|
||||
const zIntegerGeneratorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIntegerGeneratorFieldType,
|
||||
});
|
||||
export type IntegerGeneratorFieldValue = z.infer<typeof zIntegerGeneratorFieldValue>;
|
||||
export type IntegerGeneratorFieldInputInstance = z.infer<typeof zIntegerGeneratorFieldInputInstance>;
|
||||
export type IntegerGeneratorFieldInputTemplate = z.infer<typeof zIntegerGeneratorFieldInputTemplate>;
|
||||
export const isIntegerGeneratorFieldInputInstance = buildTypeGuard(zIntegerGeneratorFieldInputInstance);
|
||||
export const isIntegerGeneratorFieldInputTemplate = buildTypeGuard(zIntegerGeneratorFieldInputTemplate);
|
||||
export const resolveIntegerGeneratorField = ({ value }: IntegerGeneratorFieldInputInstance) => {
|
||||
if (value.values) {
|
||||
return value.values;
|
||||
}
|
||||
if (value.type === 'integer_generator_start_end_step') {
|
||||
return getIntegerGeneratorStartEndStepValues(value);
|
||||
}
|
||||
if (value.type === 'integer_generator_start_count_step') {
|
||||
return getIntegerGeneratorStartCountStepValues(value);
|
||||
}
|
||||
if (value.type === 'integer_generator_random') {
|
||||
return getIntegerGeneratorRandomValues(value);
|
||||
}
|
||||
assert(false, 'Invalid integer generator type');
|
||||
};
|
||||
// #endregion
|
||||
|
||||
// #region StatelessField
|
||||
/**
|
||||
* StatelessField is a catchall for stateless fields with no UI input components. They do not
|
||||
@@ -1172,6 +1264,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zColorFieldValue,
|
||||
zSchedulerFieldValue,
|
||||
zFloatGeneratorFieldValue,
|
||||
zIntegerGeneratorFieldValue,
|
||||
]);
|
||||
export type StatefulFieldValue = z.infer<typeof zStatefulFieldValue>;
|
||||
|
||||
@@ -1210,6 +1303,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zColorFieldInputInstance,
|
||||
zSchedulerFieldInputInstance,
|
||||
zFloatGeneratorFieldInputInstance,
|
||||
zIntegerGeneratorFieldInputInstance,
|
||||
]);
|
||||
|
||||
export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]);
|
||||
@@ -1252,6 +1346,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zSchedulerFieldInputTemplate,
|
||||
zStatelessFieldInputTemplate,
|
||||
zFloatGeneratorFieldInputTemplate,
|
||||
zIntegerGeneratorFieldInputTemplate,
|
||||
]);
|
||||
|
||||
export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]);
|
||||
@@ -1287,6 +1382,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zColorFieldOutputTemplate,
|
||||
zSchedulerFieldOutputTemplate,
|
||||
zFloatGeneratorFieldOutputTemplate,
|
||||
zIntegerGeneratorFieldOutputTemplate,
|
||||
]);
|
||||
|
||||
export const zFieldOutputTemplate = z.union([zStatefulFieldOutputTemplate, zStatelessFieldOutputTemplate]);
|
||||
|
||||
@@ -107,6 +107,7 @@ export const isBatchNode = (node: InvocationNode) => {
|
||||
export const isGeneratorNode = (node: InvocationNode) => {
|
||||
switch (node.data.type) {
|
||||
case 'float_generator':
|
||||
case 'integer_generator':
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
||||
@@ -20,6 +20,7 @@ import type {
|
||||
ImageFieldInputTemplate,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IntegerGeneratorFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
LoRAModelFieldInputTemplate,
|
||||
MainModelFieldInputTemplate,
|
||||
@@ -639,14 +640,14 @@ const buildSchedulerFieldInputTemplate: FieldInputTemplateBuilder<SchedulerField
|
||||
};
|
||||
|
||||
const buildFloatGeneratorFieldInputTemplate: FieldInputTemplateBuilder<FloatGeneratorFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
// schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FloatGeneratorFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? {
|
||||
default: {
|
||||
type: 'float_generator_start_end_step',
|
||||
start: 0,
|
||||
end: 1,
|
||||
@@ -658,6 +659,26 @@ const buildFloatGeneratorFieldInputTemplate: FieldInputTemplateBuilder<FloatGene
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerGeneratorFieldInputTemplate: FieldInputTemplateBuilder<IntegerGeneratorFieldInputTemplate> = ({
|
||||
// schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: IntegerGeneratorFieldInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: {
|
||||
type: 'integer_generator_start_end_step',
|
||||
start: 0,
|
||||
end: 10,
|
||||
step: 1,
|
||||
values: undefined,
|
||||
},
|
||||
};
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputTemplateBuilder> = {
|
||||
BoardField: buildBoardFieldInputTemplate,
|
||||
BooleanField: buildBooleanFieldInputTemplate,
|
||||
@@ -687,6 +708,7 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
|
||||
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
|
||||
ControlLoRAModelField: buildControlLoRAModelFieldInputTemplate,
|
||||
FloatGeneratorField: buildFloatGeneratorFieldInputTemplate,
|
||||
IntegerGeneratorField: buildIntegerGeneratorFieldInputTemplate,
|
||||
} as const;
|
||||
|
||||
export const buildFieldInputTemplate = (
|
||||
|
||||
@@ -130,6 +130,8 @@ export const parseSchema = (
|
||||
|
||||
if (type === 'float_batch' && propertyName === 'floats') {
|
||||
fieldType.batch = true;
|
||||
} else if (type === 'integer_batch' && propertyName === 'integers') {
|
||||
fieldType.batch = true;
|
||||
}
|
||||
|
||||
const fieldInputTemplate = buildFieldInputTemplate(property, propertyName, fieldType);
|
||||
@@ -195,6 +197,8 @@ export const parseSchema = (
|
||||
|
||||
if (type === 'float_generator' && propertyName === 'floats') {
|
||||
fieldType.batch = true;
|
||||
} else if (type === 'integer_generator' && propertyName === 'integers') {
|
||||
fieldType.batch = true;
|
||||
}
|
||||
|
||||
const fieldOutputTemplate = buildFieldOutputTemplate(property, propertyName, fieldType);
|
||||
|
||||
@@ -27,9 +27,11 @@ import {
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerGeneratorFieldInputInstance,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
resolveFloatGeneratorField,
|
||||
resolveIntegerGeneratorField,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
validateImageFieldCollectionValue,
|
||||
@@ -81,19 +83,36 @@ export const resolveBatchValue = (batchNode: InvocationNode, nodes: InvocationNo
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
const ownValue = batchNode.data.inputs.floats.value;
|
||||
const edgeToFloats = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'floats');
|
||||
|
||||
if (!edgeToFloats) {
|
||||
return ownValue ?? [];
|
||||
}
|
||||
const generator = nodes.find((node) => node.id === edgeToFloats.source);
|
||||
assert(generator, 'Missing edge from float generator to float batch');
|
||||
assert(isFloatGeneratorFieldInputInstance(generator.data.inputs['generator']), 'Invalid float generator');
|
||||
const generatorValue = resolveFloatGeneratorField(generator.data.inputs['generator']);
|
||||
|
||||
const generatorNode = nodes.find((node) => node.id === edgeToFloats.source);
|
||||
assert(generatorNode, 'Missing edge from float generator to float batch');
|
||||
|
||||
const generatorField = generatorNode.data.inputs['generator'];
|
||||
assert(isFloatGeneratorFieldInputInstance(generatorField), 'Invalid float generator');
|
||||
|
||||
const generatorValue = resolveFloatGeneratorField(generatorField);
|
||||
return generatorValue;
|
||||
} else if (batchNode.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
|
||||
const ownValue = batchNode.data.inputs.integers.value;
|
||||
// no generators for integers yet
|
||||
return ownValue ?? [];
|
||||
const incomers = edges.find((edge) => edge.target === batchNode.id && edge.targetHandle === 'integers');
|
||||
|
||||
if (!incomers) {
|
||||
return ownValue ?? [];
|
||||
}
|
||||
|
||||
const generatorNode = nodes.find((node) => node.id === incomers.source);
|
||||
assert(generatorNode, 'Missing edge from integer generator to integer batch');
|
||||
|
||||
const generatorField = generatorNode.data.inputs['generator'];
|
||||
assert(isIntegerGeneratorFieldInputInstance(generatorField), 'Invalid integer generator field');
|
||||
|
||||
const generatorValue = resolveIntegerGeneratorField(generatorField);
|
||||
return generatorValue;
|
||||
}
|
||||
assert(false, 'Invalid batch node type');
|
||||
};
|
||||
@@ -569,30 +588,13 @@ export const selectPromptsCount = createSelector(
|
||||
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
|
||||
);
|
||||
|
||||
const getBatchCollectionSize = (batchNode: InvocationNode) => {
|
||||
if (batchNode.data.type === 'image_batch') {
|
||||
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
|
||||
return batchNode.data.inputs.images.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'string_batch') {
|
||||
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
|
||||
return batchNode.data.inputs.strings.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'float_batch') {
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
return batchNode.data.inputs.floats.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
|
||||
return batchNode.data.inputs.integers.value?.length ?? 0;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
|
||||
createMemoizedSelector(selectNodesSlice, ({ nodes }) => {
|
||||
return nodes
|
||||
.filter(isInvocationNode)
|
||||
createMemoizedSelector(selectNodesSlice, ({ nodes, edges }) => {
|
||||
const invocationNodes = nodes.filter(isInvocationNode);
|
||||
return invocationNodes
|
||||
.filter(isBatchNode)
|
||||
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
|
||||
.map(getBatchCollectionSize);
|
||||
.map((batchNodes) => resolveBatchValue(batchNodes, invocationNodes, edges).length);
|
||||
});
|
||||
|
||||
const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');
|
||||
|
||||
Reference in New Issue
Block a user