feat(ui): support string batches

This commit is contained in:
psychedelicious
2025-01-10 10:32:23 +10:00
parent e077fe8046
commit b52b271dc4
9 changed files with 365 additions and 7 deletions

View File

@@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isImageFieldCollectionInputInstance, isStringFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
@@ -62,6 +62,34 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
}
}
// Grab string batch nodes for special handling
const stringBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'string_batch');
for (const node of stringBatchNodes) {
// Satisfy TS
const strings = node.data.inputs['strings'];
if (!isStringFieldCollectionInputInstance(strings)) {
log.warn({ nodeId: node.id }, 'String batch strings field is not a astring collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromStringBatch) {
if (!edge.targetHandle) {
break;
}
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: strings.value,
});
}
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
}
}
const batchConfig: BatchConfig = {
batch: {
graph,

View File

@@ -1,5 +1,6 @@
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@@ -51,6 +52,8 @@ import {
isSDXLRefinerModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
@@ -97,6 +100,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
return <StringFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
}

View File

@@ -0,0 +1,152 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Grid, GridItem, IconButton, Textarea } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldStringCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type {
StringFieldCollectionInputInstance,
StringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold, PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;
export const StringFieldCollectionInputComponent = memo(
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
const { nodeId, field } = props;
const store = useAppStore();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const onRemoveString = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onChangeString = useCallback(
(index: number, value: string) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onAddString = useCallback(() => {
const newValue = field.value ? [...field.value, ''] : [''];
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
return (
<Flex
className="nodrag"
position="relative"
w="full"
h="full"
maxH={64}
alignItems="stretch"
justifyContent="center"
>
{(!field.value || field.value.length === 0) && (
<Box w="full" sx={sx} data-error={isInvalid} borderRadius="base">
<IconButton
w="full"
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
</Box>
)}
{field.value && field.value.length > 0 && (
<Box w="full" h="auto" p={1} sx={sx} data-error={isInvalid} borderRadius="base">
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(1, 1fr)" gap={1}>
<IconButton
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<StringListItemContent
value={value}
index={index}
onRemoveString={onRemoveString}
onChangeString={onChangeString}
/>
</GridItem>
))}
</Grid>
</OverlayScrollbarsComponent>
</Box>
)}
</Flex>
);
}
);
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
type StringListItemContentProps = {
value: string;
index: number;
onRemoveString: (index: number) => void;
onChangeString: (index: number, value: string) => void;
};
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveString(index);
}, [index, onRemoveString]);
const onChange = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
onChangeString(index, e.target.value);
},
[index, onChangeString]
);
return (
<Flex alignItems="center" gap={1}>
<Textarea size="xs" resize="none" value={value} onChange={onChange} />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.remove')}
tooltip={t('common.remove')}
/>
</Flex>
);
});
StringListItemContent.displayName = 'StringListItemContent';

View File

@@ -3,7 +3,12 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import {
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { useMemo } from 'react';
export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
@@ -35,8 +40,20 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}
// Else special handling for individual field types
// Image collections may have min or max item counts
if (isImageFieldCollectionInputInstance(field) && isImageFieldCollectionInputTemplate(template)) {
// Image collections may have min or max item counts
if (template.minItems !== undefined && field.value.length < template.minItems) {
return true;
}
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
return true;
}
}
// String collections may have min or max item counts
if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
if (template.minItems !== undefined && field.value.length < template.minItems) {
return true;
}

View File

@@ -28,6 +28,7 @@ import type {
SDXLRefinerModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldCollectionValue,
StringFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
@@ -56,6 +57,7 @@ import {
zSDXLRefinerModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldCollectionValue,
zStringFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
@@ -311,6 +313,9 @@ export const nodesSlice = createSlice({
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
fieldValueReducer(state, action, zStringFieldValue);
},
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
fieldValueReducer(state, action, zStringFieldCollectionValue);
},
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
},
@@ -438,6 +443,7 @@ export const {
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
@@ -549,6 +555,7 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,

View File

@@ -86,6 +86,15 @@ const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
originalType: zStatelessFieldType.optional(),
});
const zStringCollectionFieldType = z.object({
name: z.literal('StringField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isStringCollectionFieldType = (
fieldType: FieldType
): fieldType is z.infer<typeof zStringCollectionFieldType> => zStringCollectionFieldType.safeParse(fieldType).success;
const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
originalType: zStatelessFieldType.optional(),
@@ -315,6 +324,52 @@ const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringFieldType,
});
// #region StringField Collection
export const zStringFieldCollectionValue = z.array(zStringFieldValue).optional();
const zStringFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zStringFieldCollectionValue,
});
const zStringFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zStringCollectionFieldType,
originalType: zFieldType.optional(),
default: zStringFieldCollectionValue,
maxLength: z.number().int().gte(0).optional(),
minLength: z.number().int().gte(0).optional(),
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
})
.refine(
(val) => {
if (val.maxLength !== undefined && val.minLength !== undefined) {
return val.maxLength >= val.minLength;
}
return true;
},
{ message: 'maxLength must be greater than or equal to minLength' }
)
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zStringFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringCollectionFieldType,
});
export type StringFieldCollectionValue = z.infer<typeof zStringFieldCollectionValue>;
export type StringFieldCollectionInputInstance = z.infer<typeof zStringFieldCollectionInputInstance>;
export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldCollectionInputTemplate>;
export const isStringFieldCollectionInputInstance = (val: unknown): val is StringFieldCollectionInputInstance =>
zStringFieldCollectionInputInstance.safeParse(val).success;
export const isStringFieldCollectionInputTemplate = (val: unknown): val is StringFieldCollectionInputTemplate =>
zStringFieldCollectionInputTemplate.safeParse(val).success;
// #endregion
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
@@ -965,6 +1020,7 @@ export const zStatefulFieldValue = z.union([
zIntegerFieldValue,
zFloatFieldValue,
zStringFieldValue,
zStringFieldCollectionValue,
zBooleanFieldValue,
zEnumFieldValue,
zImageFieldValue,
@@ -1002,6 +1058,7 @@ const zStatefulFieldInputInstance = z.union([
zIntegerFieldInputInstance,
zFloatFieldInputInstance,
zStringFieldInputInstance,
zStringFieldCollectionInputInstance,
zBooleanFieldInputInstance,
zEnumFieldInputInstance,
zImageFieldInputInstance,
@@ -1037,6 +1094,7 @@ const zStatefulFieldInputTemplate = z.union([
zIntegerFieldInputTemplate,
zFloatFieldInputTemplate,
zStringFieldInputTemplate,
zStringFieldCollectionInputTemplate,
zBooleanFieldInputTemplate,
zEnumFieldInputTemplate,
zImageFieldInputTemplate,
@@ -1076,6 +1134,7 @@ const zStatefulFieldOutputTemplate = z.union([
zIntegerFieldOutputTemplate,
zFloatFieldOutputTemplate,
zStringFieldOutputTemplate,
zStringFieldCollectionOutputTemplate,
zBooleanFieldOutputTemplate,
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,

View File

@@ -1,5 +1,6 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import type { InvocationNode } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
@@ -7,6 +8,17 @@ import { v4 as uuidv4 } from 'uuid';
const log = logger('workflows');
// These nodes are not executable, they exist for the frontend only
const filterNonExecutableNodes = (node: InvocationNode) => {
if (node.data.type === 'image_batch') {
return false;
}
if (node.data.type === 'string_batch') {
return false;
}
return true;
};
/**
* Builds a graph from the node editor state.
*/
@@ -14,7 +26,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const { nodes, edges } = nodesState;
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');
const filteredNodes = nodes.filter(isInvocationNode).filter(filterNonExecutableNodes);
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {

View File

@@ -28,12 +28,17 @@ import type {
SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldCollectionInputTemplate,
StringFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { isImageCollectionFieldType, isStatefulFieldType } from 'features/nodes/types/field';
import {
isImageCollectionFieldType,
isStatefulFieldType,
isStringCollectionFieldType,
} from 'features/nodes/types/field';
import type { InvocationFieldSchema } from 'features/nodes/types/openapi';
import { isSchemaObject } from 'features/nodes/types/openapi';
import { t } from 'i18next';
@@ -133,6 +138,36 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
return template;
};
const buildStringFieldCollectionInputTemplate: FieldInputTemplateBuilder<StringFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: StringFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
return template;
};
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -569,12 +604,17 @@ export const buildFieldInputTemplate = (
if (isStatefulFieldType(fieldType)) {
if (isImageCollectionFieldType(fieldType)) {
fieldType;
return buildImageFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else if (isStringCollectionFieldType(fieldType)) {
return buildStringFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else {
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
const template = builder({

View File

@@ -18,7 +18,12 @@ import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { NodesState, Templates } from 'features/nodes/store/types';
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import {
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
@@ -130,6 +135,37 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
});
return;
}
} else if (
field.value &&
isStringFieldCollectionInputInstance(field) &&
isStringFieldCollectionInputTemplate(fieldTemplate)
) {
// String collections may have min or max items to validate
// TODO(psyche): generalize this to other collection types
if (fieldTemplate.minItems !== undefined && fieldTemplate.minItems > 0 && field.value.length === 0) {
reasons.push({ content: i18n.t('parameters.invoke.collectionEmpty', baseTKeyOptions) });
return;
}
if (fieldTemplate.minItems !== undefined && field.value.length < fieldTemplate.minItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooFewItems', {
...baseTKeyOptions,
size: field.value.length,
minItems: fieldTemplate.minItems,
}),
});
return;
}
if (fieldTemplate.maxItems !== undefined && field.value.length > fieldTemplate.maxItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooManyItems', {
...baseTKeyOptions,
size: field.value.length,
maxItems: fieldTemplate.maxItems,
}),
});
return;
}
}
});
});