mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
feat(ui): support float batches
This commit is contained in:
@@ -104,6 +104,21 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
addBatchDataCollectionItem(edgesFromStringBatch, integers.value);
|
||||
}
|
||||
|
||||
// Grab float batch nodes for special handling
|
||||
const floatBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'float_batch');
|
||||
for (const node of floatBatchNodes) {
|
||||
// Satisfy TS
|
||||
const floats = node.data.inputs['floats'];
|
||||
if (!isIntegerFieldCollectionInputInstance(floats)) {
|
||||
log.warn({ nodeId: node.id }, 'Integer batch floats field is not a float collection');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
addBatchDataCollectionItem(edgesFromStringBatch, floats.value);
|
||||
}
|
||||
|
||||
const batchConfig: BatchConfig = {
|
||||
batch: {
|
||||
graph,
|
||||
|
||||
@@ -23,6 +23,8 @@ import {
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isFloatFieldInputInstance,
|
||||
isFloatFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
@@ -115,16 +117,22 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (
|
||||
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
|
||||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
|
||||
) {
|
||||
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
|
||||
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/comp
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import { fieldNumberCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FloatFieldCollectionInputInstance,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
IntegerFieldCollectionInputInstance,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
@@ -28,7 +30,11 @@ const sx = {
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const NumberFieldCollectionInputComponent = memo(
|
||||
(props: FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>) => {
|
||||
(
|
||||
props:
|
||||
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
|
||||
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const store = useAppStore();
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
@@ -13,7 +15,7 @@ import {
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
validateImageFieldCollectionValue,
|
||||
validateIntegerFieldCollectionValue,
|
||||
validateNumberFieldCollectionValue,
|
||||
validateStringFieldCollectionValue,
|
||||
} from 'features/nodes/types/fieldValidators';
|
||||
import { useMemo } from 'react';
|
||||
@@ -61,7 +63,13 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
|
||||
if (validateIntegerFieldCollectionValue(field.value, template).length > 0) {
|
||||
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
|
||||
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,6 +45,7 @@ import {
|
||||
zControlLoRAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zEnumFieldValue,
|
||||
zFloatFieldCollectionValue,
|
||||
zFloatFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zImageFieldCollectionValue,
|
||||
@@ -322,7 +323,7 @@ export const nodesSlice = createSlice({
|
||||
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
|
||||
},
|
||||
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerFieldCollectionValue);
|
||||
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
|
||||
},
|
||||
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
|
||||
fieldValueReducer(state, action, zBooleanFieldValue);
|
||||
|
||||
@@ -91,6 +91,15 @@ const zFloatFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFloatCollectionFieldType = z.object({
|
||||
name: z.literal('FloatField'),
|
||||
cardinality: z.literal(COLLECTION),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
export const isFloatCollectionFieldType = (
|
||||
fieldType: FieldType
|
||||
): fieldType is z.infer<typeof zFloatCollectionFieldType> => zFloatCollectionFieldType.safeParse(fieldType).success;
|
||||
|
||||
const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -346,6 +355,46 @@ export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputT
|
||||
zFloatFieldInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region FloatField Collection
|
||||
export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
|
||||
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFloatFieldCollectionValue,
|
||||
});
|
||||
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
.extend({
|
||||
type: zFloatCollectionFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFloatFieldCollectionValue,
|
||||
maxItems: z.number().int().gte(0).optional(),
|
||||
minItems: z.number().int().gte(0).optional(),
|
||||
multipleOf: z.number().int().optional(),
|
||||
maximum: z.number().optional(),
|
||||
exclusiveMaximum: z.number().optional(),
|
||||
minimum: z.number().optional(),
|
||||
exclusiveMinimum: z.number().optional(),
|
||||
})
|
||||
.refine(
|
||||
(val) => {
|
||||
if (val.maxItems !== undefined && val.minItems !== undefined) {
|
||||
return val.maxItems >= val.minItems;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxItems must be greater than or equal to minItems' }
|
||||
);
|
||||
|
||||
const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zFloatCollectionFieldType,
|
||||
});
|
||||
export type FloatFieldCollectionValue = z.infer<typeof zFloatFieldCollectionValue>;
|
||||
export type FloatFieldCollectionInputInstance = z.infer<typeof zFloatFieldCollectionInputInstance>;
|
||||
export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollectionInputTemplate>;
|
||||
export const isFloatFieldCollectionInputInstance = (val: unknown): val is FloatFieldCollectionInputInstance =>
|
||||
zFloatFieldCollectionInputInstance.safeParse(val).success;
|
||||
export const isFloatFieldCollectionInputTemplate = (val: unknown): val is FloatFieldCollectionInputTemplate =>
|
||||
zFloatFieldCollectionInputTemplate.safeParse(val).success;
|
||||
// #endregion
|
||||
|
||||
// #region StringField
|
||||
|
||||
export const zStringFieldValue = z.string();
|
||||
@@ -1069,6 +1118,7 @@ export const zStatefulFieldValue = z.union([
|
||||
zIntegerFieldValue,
|
||||
zIntegerFieldCollectionValue,
|
||||
zFloatFieldValue,
|
||||
zFloatFieldCollectionValue,
|
||||
zStringFieldValue,
|
||||
zStringFieldCollectionValue,
|
||||
zBooleanFieldValue,
|
||||
@@ -1108,6 +1158,7 @@ const zStatefulFieldInputInstance = z.union([
|
||||
zIntegerFieldInputInstance,
|
||||
zIntegerFieldCollectionInputInstance,
|
||||
zFloatFieldInputInstance,
|
||||
zFloatFieldCollectionInputInstance,
|
||||
zStringFieldInputInstance,
|
||||
zStringFieldCollectionInputInstance,
|
||||
zBooleanFieldInputInstance,
|
||||
@@ -1145,6 +1196,7 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
zIntegerFieldInputTemplate,
|
||||
zIntegerFieldCollectionInputTemplate,
|
||||
zFloatFieldInputTemplate,
|
||||
zFloatFieldCollectionInputTemplate,
|
||||
zStringFieldInputTemplate,
|
||||
zStringFieldCollectionInputTemplate,
|
||||
zBooleanFieldInputTemplate,
|
||||
@@ -1186,6 +1238,7 @@ const zStatefulFieldOutputTemplate = z.union([
|
||||
zIntegerFieldOutputTemplate,
|
||||
zIntegerFieldCollectionOutputTemplate,
|
||||
zFloatFieldOutputTemplate,
|
||||
zFloatFieldCollectionOutputTemplate,
|
||||
zStringFieldOutputTemplate,
|
||||
zStringFieldCollectionOutputTemplate,
|
||||
zBooleanFieldOutputTemplate,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import type {
|
||||
FloatFieldCollectionInputTemplate,
|
||||
FloatFieldCollectionValue,
|
||||
ImageFieldCollectionInputTemplate,
|
||||
ImageFieldCollectionValue,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
@@ -56,9 +58,9 @@ export const validateStringFieldCollectionValue = (
|
||||
return reasons;
|
||||
};
|
||||
|
||||
export const validateIntegerFieldCollectionValue = (
|
||||
value: NonNullable<IntegerFieldCollectionValue>,
|
||||
template: IntegerFieldCollectionInputTemplate
|
||||
export const validateNumberFieldCollectionValue = (
|
||||
value: NonNullable<IntegerFieldCollectionValue> | NonNullable<FloatFieldCollectionValue>,
|
||||
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
|
||||
): string[] => {
|
||||
const reasons: string[] = [];
|
||||
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum } = template;
|
||||
|
||||
@@ -19,6 +19,9 @@ const filterNonExecutableNodes = (node: InvocationNode) => {
|
||||
if (node.data.type === 'integer_batch') {
|
||||
return false;
|
||||
}
|
||||
if (node.data.type === 'float_batch') {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
EnumFieldInputTemplate,
|
||||
FieldInputTemplate,
|
||||
FieldType,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
FloatFieldInputTemplate,
|
||||
FluxMainModelFieldInputTemplate,
|
||||
FluxVAEModelFieldInputTemplate,
|
||||
@@ -36,6 +37,7 @@ import type {
|
||||
VAEModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
isFloatCollectionFieldType,
|
||||
isImageCollectionFieldType,
|
||||
isIntegerCollectionFieldType,
|
||||
isStatefulFieldType,
|
||||
@@ -160,6 +162,48 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTem
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatFieldCollectionInputTemplate: FieldInputTemplateBuilder<FloatFieldCollectionInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FloatFieldCollectionInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
|
||||
};
|
||||
|
||||
if (schemaObject.minItems !== undefined) {
|
||||
template.minItems = schemaObject.minItems;
|
||||
}
|
||||
|
||||
if (schemaObject.maxItems !== undefined) {
|
||||
template.maxItems = schemaObject.maxItems;
|
||||
}
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -665,6 +709,12 @@ export const buildFieldInputTemplate = (
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else if (isFloatCollectionFieldType(fieldType)) {
|
||||
return buildFloatFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else {
|
||||
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
|
||||
const template = builder({
|
||||
|
||||
@@ -19,6 +19,8 @@ import type { NodesState, Templates } from 'features/nodes/store/types';
|
||||
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
@@ -28,7 +30,7 @@ import {
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
validateImageFieldCollectionValue,
|
||||
validateIntegerFieldCollectionValue,
|
||||
validateNumberFieldCollectionValue,
|
||||
validateStringFieldCollectionValue,
|
||||
} from 'features/nodes/types/fieldValidators';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
@@ -126,7 +128,14 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
isIntegerFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateIntegerFieldCollectionValue(field.value, fieldTemplate);
|
||||
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
} else if (
|
||||
field.value &&
|
||||
isFloatFieldCollectionInputInstance(field) &&
|
||||
isFloatFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user