feat(ui): support float batches

This commit is contained in:
psychedelicious
2025-01-10 14:56:35 +10:00
parent c26b3cd54f
commit bfe6d98cba
10 changed files with 168 additions and 13 deletions

View File

@@ -104,6 +104,21 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
addBatchDataCollectionItem(edgesFromStringBatch, integers.value);
}
// Grab float batch nodes for special handling
const floatBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'float_batch');
for (const node of floatBatchNodes) {
// Satisfy TS
const floats = node.data.inputs['floats'];
if (!isIntegerFieldCollectionInputInstance(floats)) {
log.warn({ nodeId: node.id }, 'Integer batch floats field is not a float collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, floats.value);
}
const batchConfig: BatchConfig = {
batch: {
graph,

View File

@@ -23,6 +23,8 @@ import {
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isFluxMainModelFieldInputInstance,
@@ -115,16 +117,22 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
) {
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -6,6 +6,8 @@ import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/comp
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldNumberCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
@@ -28,7 +30,11 @@ const sx = {
} satisfies SystemStyleObject;
export const NumberFieldCollectionInputComponent = memo(
(props: FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>) => {
(
props:
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
const store = useAppStore();

View File

@@ -4,6 +4,8 @@ import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import {
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
@@ -13,7 +15,7 @@ import {
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
validateIntegerFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import { useMemo } from 'react';
@@ -61,7 +63,13 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
if (validateIntegerFieldCollectionValue(field.value, template).length > 0) {
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}

View File

@@ -45,6 +45,7 @@ import {
zControlLoRAModelFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
zFloatFieldCollectionValue,
zFloatFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
@@ -322,7 +323,7 @@ export const nodesSlice = createSlice({
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
},
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
fieldValueReducer(state, action, zIntegerFieldCollectionValue);
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
},
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
fieldValueReducer(state, action, zBooleanFieldValue);

View File

@@ -91,6 +91,15 @@ const zFloatFieldType = zFieldTypeBase.extend({
name: z.literal('FloatField'),
originalType: zStatelessFieldType.optional(),
});
const zFloatCollectionFieldType = z.object({
name: z.literal('FloatField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isFloatCollectionFieldType = (
fieldType: FieldType
): fieldType is z.infer<typeof zFloatCollectionFieldType> => zFloatCollectionFieldType.safeParse(fieldType).success;
const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
originalType: zStatelessFieldType.optional(),
@@ -346,6 +355,46 @@ export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputT
zFloatFieldInputTemplate.safeParse(val).success;
// #endregion
// #region FloatField Collection
export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zFloatFieldCollectionValue,
});
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zFloatCollectionFieldType,
originalType: zFieldType.optional(),
default: zFloatFieldCollectionValue,
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
multipleOf: z.number().int().optional(),
maximum: z.number().optional(),
exclusiveMaximum: z.number().optional(),
minimum: z.number().optional(),
exclusiveMinimum: z.number().optional(),
})
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatCollectionFieldType,
});
export type FloatFieldCollectionValue = z.infer<typeof zFloatFieldCollectionValue>;
export type FloatFieldCollectionInputInstance = z.infer<typeof zFloatFieldCollectionInputInstance>;
export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollectionInputTemplate>;
export const isFloatFieldCollectionInputInstance = (val: unknown): val is FloatFieldCollectionInputInstance =>
zFloatFieldCollectionInputInstance.safeParse(val).success;
export const isFloatFieldCollectionInputTemplate = (val: unknown): val is FloatFieldCollectionInputTemplate =>
zFloatFieldCollectionInputTemplate.safeParse(val).success;
// #endregion
// #region StringField
export const zStringFieldValue = z.string();
@@ -1069,6 +1118,7 @@ export const zStatefulFieldValue = z.union([
zIntegerFieldValue,
zIntegerFieldCollectionValue,
zFloatFieldValue,
zFloatFieldCollectionValue,
zStringFieldValue,
zStringFieldCollectionValue,
zBooleanFieldValue,
@@ -1108,6 +1158,7 @@ const zStatefulFieldInputInstance = z.union([
zIntegerFieldInputInstance,
zIntegerFieldCollectionInputInstance,
zFloatFieldInputInstance,
zFloatFieldCollectionInputInstance,
zStringFieldInputInstance,
zStringFieldCollectionInputInstance,
zBooleanFieldInputInstance,
@@ -1145,6 +1196,7 @@ const zStatefulFieldInputTemplate = z.union([
zIntegerFieldInputTemplate,
zIntegerFieldCollectionInputTemplate,
zFloatFieldInputTemplate,
zFloatFieldCollectionInputTemplate,
zStringFieldInputTemplate,
zStringFieldCollectionInputTemplate,
zBooleanFieldInputTemplate,
@@ -1186,6 +1238,7 @@ const zStatefulFieldOutputTemplate = z.union([
zIntegerFieldOutputTemplate,
zIntegerFieldCollectionOutputTemplate,
zFloatFieldOutputTemplate,
zFloatFieldCollectionOutputTemplate,
zStringFieldOutputTemplate,
zStringFieldCollectionOutputTemplate,
zBooleanFieldOutputTemplate,

View File

@@ -1,4 +1,6 @@
import type {
FloatFieldCollectionInputTemplate,
FloatFieldCollectionValue,
ImageFieldCollectionInputTemplate,
ImageFieldCollectionValue,
IntegerFieldCollectionInputTemplate,
@@ -56,9 +58,9 @@ export const validateStringFieldCollectionValue = (
return reasons;
};
export const validateIntegerFieldCollectionValue = (
value: NonNullable<IntegerFieldCollectionValue>,
template: IntegerFieldCollectionInputTemplate
export const validateNumberFieldCollectionValue = (
value: NonNullable<IntegerFieldCollectionValue> | NonNullable<FloatFieldCollectionValue>,
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum } = template;

View File

@@ -19,6 +19,9 @@ const filterNonExecutableNodes = (node: InvocationNode) => {
if (node.data.type === 'integer_batch') {
return false;
}
if (node.data.type === 'float_batch') {
return false;
}
return true;
};

View File

@@ -11,6 +11,7 @@ import type {
EnumFieldInputTemplate,
FieldInputTemplate,
FieldType,
FloatFieldCollectionInputTemplate,
FloatFieldInputTemplate,
FluxMainModelFieldInputTemplate,
FluxVAEModelFieldInputTemplate,
@@ -36,6 +37,7 @@ import type {
VAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import {
isFloatCollectionFieldType,
isImageCollectionFieldType,
isIntegerCollectionFieldType,
isStatefulFieldType,
@@ -160,6 +162,48 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTem
return template;
};
const buildFloatFieldCollectionInputTemplate: FieldInputTemplateBuilder<FloatFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FloatFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -665,6 +709,12 @@ export const buildFieldInputTemplate = (
baseField,
fieldType,
});
} else if (isFloatCollectionFieldType(fieldType)) {
return buildFloatFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else {
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
const template = builder({

View File

@@ -19,6 +19,8 @@ import type { NodesState, Templates } from 'features/nodes/store/types';
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import {
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
@@ -28,7 +30,7 @@ import {
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
validateIntegerFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import { isInvocationNode } from 'features/nodes/types/invocation';
@@ -126,7 +128,14 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
isIntegerFieldCollectionInputInstance(field) &&
isIntegerFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateIntegerFieldCollectionValue(field.value, fieldTemplate);
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isFloatFieldCollectionInputInstance(field) &&
isFloatFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
}
});