diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts index 2a6ace46f2..584719b1a5 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedNodes.ts @@ -2,7 +2,7 @@ import { logger } from 'app/logging/logger'; import { enqueueRequested } from 'app/store/actions'; import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; import { selectNodesSlice } from 'features/nodes/store/selectors'; -import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field'; +import { isImageFieldCollectionInputInstance, isStringFieldCollectionInputInstance } from 'features/nodes/types/field'; import { isInvocationNode } from 'features/nodes/types/invocation'; import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph'; import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow'; @@ -62,6 +62,34 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) = } } + // Grab string batch nodes for special handling + const stringBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'string_batch'); + for (const node of stringBatchNodes) { + // Satisfy TS + const strings = node.data.inputs['strings']; + if (!isStringFieldCollectionInputInstance(strings)) { + log.warn({ nodeId: node.id }, 'String batch strings field is not a astring collection'); + break; + } + + // Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead + const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value'); + const batchDataCollectionItem: NonNullable[number] = []; + for (const edge of edgesFromStringBatch) { + if (!edge.targetHandle) { + break; + } + batchDataCollectionItem.push({ + node_path: edge.target, + field_name: edge.targetHandle, + items: strings.value, + }); + } + if (batchDataCollectionItem.length > 0) { + data.push(batchDataCollectionItem); + } + } + const batchConfig: BatchConfig = { batch: { graph, diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx index 3874221542..d3190fe26a 100644 --- a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx @@ -1,5 +1,6 @@ import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent'; import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent'; +import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent'; import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance'; import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate'; import { @@ -51,6 +52,8 @@ import { isSDXLRefinerModelFieldInputTemplate, isSpandrelImageToImageModelFieldInputInstance, isSpandrelImageToImageModelFieldInputTemplate, + isStringFieldCollectionInputInstance, + isStringFieldCollectionInputTemplate, isStringFieldInputInstance, isStringFieldInputTemplate, isT2IAdapterModelFieldInputInstance, @@ -97,6 +100,10 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => { const fieldInstance = useFieldInputInstance(nodeId, fieldName); const fieldTemplate = useFieldInputTemplate(nodeId, fieldName); + if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) { + return ; + } + if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) { return ; } diff --git a/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent.tsx b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent.tsx new file mode 100644 index 0000000000..15934bf201 --- /dev/null +++ b/invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent.tsx @@ -0,0 +1,152 @@ +import type { SystemStyleObject } from '@invoke-ai/ui-library'; +import { Box, Flex, Grid, GridItem, IconButton, Textarea } from '@invoke-ai/ui-library'; +import { useAppStore } from 'app/store/nanostores/store'; +import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants'; +import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid'; +import { fieldStringCollectionValueChanged } from 'features/nodes/store/nodesSlice'; +import type { + StringFieldCollectionInputInstance, + StringFieldCollectionInputTemplate, +} from 'features/nodes/types/field'; +import { OverlayScrollbarsComponent } from 'overlayscrollbars-react'; +import type { ChangeEvent } from 'react'; +import { memo, useCallback } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiPlusBold, PiXBold } from 'react-icons/pi'; + +import type { FieldComponentProps } from './types'; + +const overlayscrollbarsOptions = getOverlayScrollbarsParams().options; + +const sx = { + borderWidth: 1, + '&[data-error=true]': { + borderColor: 'error.500', + borderStyle: 'solid', + }, +} satisfies SystemStyleObject; + +export const StringFieldCollectionInputComponent = memo( + (props: FieldComponentProps) => { + const { nodeId, field } = props; + const store = useAppStore(); + + const isInvalid = useFieldIsInvalid(nodeId, field.name); + + const onRemoveString = useCallback( + (index: number) => { + const newValue = field.value ? [...field.value] : []; + newValue.splice(index, 1); + store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue })); + }, + [field.name, field.value, nodeId, store] + ); + + const onChangeString = useCallback( + (index: number, value: string) => { + const newValue = field.value ? [...field.value] : []; + newValue[index] = value; + store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue })); + }, + [field.name, field.value, nodeId, store] + ); + + const onAddString = useCallback(() => { + const newValue = field.value ? [...field.value, ''] : ['']; + store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue })); + }, [field.name, field.value, nodeId, store]); + + return ( + + {(!field.value || field.value.length === 0) && ( + + } + variant="ghost" + size="sm" + /> + + )} + {field.value && field.value.length > 0 && ( + + + + } + variant="ghost" + size="sm" + /> + {field.value.map((value, index) => ( + + + + ))} + + + + )} + + ); + } +); + +StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent'; + +type StringListItemContentProps = { + value: string; + index: number; + onRemoveString: (index: number) => void; + onChangeString: (index: number, value: string) => void; +}; + +const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => { + const { t } = useTranslation(); + + const onClickRemove = useCallback(() => { + onRemoveString(index); + }, [index, onRemoveString]); + const onChange = useCallback( + (e: ChangeEvent) => { + onChangeString(index, e.target.value); + }, + [index, onChangeString] + ); + return ( + +