refactor(ui): persistent workflow field value generators

Previously, workflow generators existed on a layer above the workflow and were ephemeral. Generators could be run and the result saved to the workflow.

There was no way to have the generator and its settings to be an inherent part of the workflow. When you refresh the page or load a workflow, the generator settings are reset.

For example, a number collection field's value is a list of numbers. When you use a range generator for that field, the generated list of numbers is written to the workflow. When you refresh the page or load the workflow later, all you have is the list of numbers.

This change makes generators a part of the workflow itself. In other words, the a field's generator settings are persisted to the workflow alongside the field, and eligible fields can be thought of as having a generator _as_ their state.

For example, consider a number collection. If the field has a generator enabled, the generator settings are stored in the workflow directly, in that field's state. When we need to access the field's value, if it has a generator, we run the generator. If there is no generator, we get the directly-entered value.

This enables an important use-case, where the workflow editor can set up a good baseline generator and save it to the workflow.

Then the workflow user loads the workflow, and they just see the generator settings, importantly with the default values set by the editor. They never need to see a big list of values.

- Add generator persistence to number collection field values.
- Update all logic that references field values to use "resolved" field values if the field is a number collection field. This includes validation logic.
- Rework the generator UI. Generators are now part of each field, not a separate modal. You can enable the generator, reset its values, commit them and then edit them. Or, disable the generator to manually edit the values.
- Support locking the linear view mode. If the workflow editor locks the field, the linear view will be slimmed down, showing only the generator fields.
- Rework how the "reset to default value" functionality works with exposed fields to also work with generators. **Unfortunately, this did require some changes to redux state that I cannot easily handle in a redux state migration. As a result, on the first run after updating Invoke, their workflow editor state will be erased.**
This commit is contained in:
psychedelicious
2025-01-15 15:30:45 +11:00
parent 42d59b961e
commit 6df3e9f960
25 changed files with 719 additions and 348 deletions

View File

@@ -855,7 +855,13 @@
},
"nodes": {
"noBatchGroup": "no group",
"generator": "Generator",
"generatedValues": "Generated Values",
"commitValues": "Commit Values",
"addValue": "Add Value",
"addNode": "Add Node",
"lockLinearView": "Lock Linear View",
"unlockLinearView": "Unlock Linear View",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
"animatedEdges": "Animated Edges",
@@ -994,11 +1000,7 @@
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
"modelAccessError": "Unable to find model {{key}}, resetting to default",
"saveToGallery": "Save To Gallery",
"addItem": "Add Item",
"generateValues": "Generate Values",
"floatRangeGenerator": "Float Range Generator",
"integerRangeGenerator": "Integer Range Generator"
"saveToGallery": "Save To Gallery"
},
"parameters": {
"aspect": "Aspect",

View File

@@ -22,8 +22,6 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
import { FloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
import { IntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
@@ -112,8 +110,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
<ImageContextMenu />
<FullscreenDropzone />
<VideosModal />
<FloatRangeGeneratorModal />
<IntegerRangeGeneratorModal />
</ErrorBoundary>
);
};

View File

@@ -9,6 +9,7 @@ import {
isIntegerFieldCollectionInputInstance,
isStringFieldCollectionInputInstance,
} from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
@@ -140,10 +141,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const resolvedValue = resolveNumberFieldCollectionValue(integers);
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, integers.value);
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, integers.value);
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}
@@ -163,10 +165,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const resolvedValue = resolveNumberFieldCollectionValue(floats);
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, floats.value);
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, floats.value);
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}

View File

@@ -166,8 +166,10 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
serializableCheck: false,
immutableCheck: false,
// serializableCheck: import.meta.env.MODE === 'development',
// immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)

View File

@@ -1,105 +0,0 @@
import {
Button,
CompositeNumberInput,
Flex,
FormControl,
FormLabel,
IconButton,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { round } from 'lodash-es';
import { atom } from 'nanostores';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
type FloatRangeGeneratorModalState = {
isOpen: boolean;
onSave: (values: number[]) => void;
};
const $floatRangeGeneratorModal = atom<FloatRangeGeneratorModalState>({
isOpen: false,
onSave: () => {},
});
export const openFloatRangeGeneratorModal = (onSave: (values: number[]) => void) => {
$floatRangeGeneratorModal.set({ ...$floatRangeGeneratorModal.get(), isOpen: true, onSave });
};
const onClose = () => {
$floatRangeGeneratorModal.set({ ...$floatRangeGeneratorModal.get(), isOpen: false });
};
export const FloatRangeGeneratorModal = memo(() => {
const { isOpen, onSave } = useStore($floatRangeGeneratorModal);
const { t } = useTranslation();
const [start, setStart] = useState(0);
const [step, setStep] = useState(1);
const [count, setCount] = useState(1);
const values = useMemo(() => Array.from({ length: count }, (_, i) => start + i * step), [start, step, count]);
const onReset = useCallback(() => {
setStart(0);
setStep(1);
setCount(1);
}, []);
const onClickSave = useCallback(() => {
onSave(values);
onClose();
}, [onSave, values]);
return (
<Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent>
<ModalHeader>{t('nodes.floatRangeGenerator')}</ModalHeader>
<ModalCloseButton />
<ModalBody minH={200}>
<Flex gap={4} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput value={start} onChange={setStart} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={count} onChange={setCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.step')}</FormLabel>
<CompositeNumberInput value={step} onChange={setStep} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
</Flex>
<Flex w="full" h="auto" flexDir="column" gap={2} pt={4}>
<FormLabel>{t('common.values')}</FormLabel>
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base">
<Text fontFamily="monospace" fontSize="md" color="base.300">
{values.map((val) => round(val, 2)).join(', ')}
</Text>
</Flex>
</Flex>
</ModalBody>
<ModalFooter gap={2}>
<Button onClick={onClose} variant="ghost">
{t('common.cancel')}
</Button>
<Button onClick={onClickSave}>{t('common.save')}</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
});
FloatRangeGeneratorModal.displayName = 'FloatRangeGeneratorModal';

View File

@@ -1,103 +0,0 @@
import {
Button,
CompositeNumberInput,
Flex,
FormControl,
FormLabel,
IconButton,
Modal,
ModalBody,
ModalCloseButton,
ModalContent,
ModalFooter,
ModalHeader,
ModalOverlay,
Text,
} from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { atom } from 'nanostores';
import { memo, useCallback, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
type IntegerRangeGeneratorModalState = {
isOpen: boolean;
onSave: (values: number[]) => void;
};
const $integerRangeGeneratorModal = atom<IntegerRangeGeneratorModalState>({
isOpen: false,
onSave: () => {},
});
export const openIntegerRangeGeneratorModal = (onSave: (values: number[]) => void) => {
$integerRangeGeneratorModal.set({ ...$integerRangeGeneratorModal.get(), isOpen: true, onSave });
};
const onClose = () => {
$integerRangeGeneratorModal.set({ ...$integerRangeGeneratorModal.get(), isOpen: false });
};
export const IntegerRangeGeneratorModal = memo(() => {
const { isOpen, onSave } = useStore($integerRangeGeneratorModal);
const { t } = useTranslation();
const [start, setStart] = useState(0);
const [step, setStep] = useState(1);
const [count, setCount] = useState(1);
const values = useMemo(() => Array.from({ length: count }, (_, i) => start + i * step), [start, step, count]);
const onReset = useCallback(() => {
setStart(0);
setStep(1);
setCount(1);
}, []);
const onClickSave = useCallback(() => {
onSave(values);
onClose();
}, [onSave, values]);
return (
<Modal isOpen={isOpen} onClose={onClose} isCentered>
<ModalOverlay />
<ModalContent>
<ModalHeader>{t('nodes.integerRangeGenerator')}</ModalHeader>
<ModalCloseButton />
<ModalBody minH={200}>
<Flex gap={4} alignItems="flex-end">
<FormControl orientation="vertical">
<FormLabel>{t('common.start')}</FormLabel>
<CompositeNumberInput value={start} onChange={setStart} min={-Infinity} max={Infinity} step={1} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.count')}</FormLabel>
<CompositeNumberInput value={count} onChange={setCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical">
<FormLabel>{t('common.step')}</FormLabel>
<CompositeNumberInput value={step} onChange={setStep} min={-Infinity} max={Infinity} step={1} />
</FormControl>
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
</Flex>
<Flex w="full" h="auto" flexDir="column" gap={2} pt={4}>
<FormLabel>{t('common.values')}</FormLabel>
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base">
<Text fontFamily="monospace" fontSize="md" color="base.300">
{values.join(', ')}
</Text>
</Flex>
</Flex>
</ModalBody>
<ModalFooter gap={2}>
<Button onClick={onClose} variant="ghost">
{t('common.cancel')}
</Button>
<Button onClick={onClickSave}>{t('common.save')}</Button>
</ModalFooter>
</ModalContent>
</Modal>
);
});
IntegerRangeGeneratorModal.displayName = 'IntegerRangeGeneratorModal';

View File

@@ -43,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
{fieldNames.connectionFields.map((fieldName, i) => (
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
<InvocationInputFieldCheck nodeId={nodeId} fieldName={fieldName}>
<InputField nodeId={nodeId} fieldName={fieldName} />
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
</InvocationInputFieldCheck>
</GridItem>
))}
@@ -59,7 +59,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} />
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
</InvocationInputFieldCheck>
))}
{fieldNames.missingFields.map((fieldName) => (
@@ -68,7 +68,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} />
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
</InvocationInputFieldCheck>
))}
</Flex>

View File

@@ -1,7 +1,7 @@
import { IconButton } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import {
selectWorkflowSlice,
workflowExposedFieldAdded,
@@ -19,7 +19,7 @@ type Props = {
const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const value = useFieldValue(nodeId, fieldName);
const field = useFieldInputInstance(nodeId, fieldName);
const selectIsExposed = useMemo(
() =>
createSelector(selectWorkflowSlice, (workflow) => {
@@ -31,8 +31,11 @@ const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
const isExposed = useAppSelector(selectIsExposed);
const handleExposeField = useCallback(() => {
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, value }));
}, [dispatch, fieldName, nodeId, value]);
if (!field) {
return;
}
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, field }));
}, [dispatch, field, fieldName, nodeId]);
const handleUnexposeField = useCallback(() => {
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));

View File

@@ -14,9 +14,10 @@ import { InputFieldWrapper } from './InputFieldWrapper';
interface Props {
nodeId: string;
fieldName: string;
isLinearView: boolean;
}
const InputField = ({ nodeId, fieldName }: Props) => {
const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
const [isHovered, setIsHovered] = useState(false);
const isInvalid = useFieldIsInvalid(nodeId, fieldName);
@@ -69,12 +70,12 @@ const InputField = ({ nodeId, fieldName }: Props) => {
px={2}
>
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
<Flex gap={1}>
<Flex gap={1} alignItems="center">
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" isInvalid={isInvalid} withTooltip />
{isHovered && <FieldResetToDefaultValueButton nodeId={nodeId} fieldName={fieldName} />}
{isHovered && <FieldLinearViewToggle nodeId={nodeId} fieldName={fieldName} />}
</Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={isLinearView} />
</Flex>
</FormControl>

View File

@@ -99,109 +99,285 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
nodeId: string;
fieldName: string;
isLinearView: boolean;
};
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
return <StringFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<StringFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<StringFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<BooleanFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<NumberFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<NumberFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<NumberFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<NumberFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<EnumFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
return <ImageFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ImageFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ImageFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
return <BoardFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<BoardFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ModelIdentifierFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<RefinerModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<VAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<T5EncoderModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<CLIPEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<CLIPLEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<CLIPGEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isControlLoRAModelFieldInputInstance(fieldInstance) && isControlLoRAModelFieldInputTemplate(fieldTemplate)) {
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ControlLoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<FluxVAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<LoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ControlNetModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<IPAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<T2IAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (
@@ -213,28 +389,64 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ColorFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<FluxMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<SD3MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<SDXLMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<SchedulerFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (fieldTemplate) {

View File

@@ -97,7 +97,11 @@ const LinearViewFieldInternal = ({ fieldIdentifier }: Props) => {
icon={<PiTrashSimpleBold />}
/>
</Flex>
<InputFieldRenderer nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName} />
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
isLinearView={true}
/>
</Flex>
</Flex>
<DndListDropIndicator dndState={dndListState} />

View File

@@ -0,0 +1,65 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
import {
type FloatRangeStartStepCountGenerator,
getDefaultFloatRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
type FloatRangeGeneratorProps = {
state: FloatRangeStartStepCountGenerator;
onChange: (state: FloatRangeStartStepCountGenerator) => void;
};
export const FloatRangeGenerator = memo(({ state, onChange }: FloatRangeGeneratorProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
const onReset = useCallback(() => {
onChange(getDefaultFloatRangeStartStepCountGenerator());
}, [onChange]);
return (
<Flex gap={1} alignItems="flex-end" p={1}>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<IconButton
onClick={onReset}
aria-label={t('common.reset')}
icon={<PiArrowCounterClockwiseBold />}
variant="ghost"
/>
</Flex>
);
});
FloatRangeGenerator.displayName = 'FloatRangeGenerator';

View File

@@ -1,33 +1,45 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Button,
ButtonGroup,
CompositeNumberInput,
Divider,
Flex,
FormControl,
FormLabel,
Grid,
GridItem,
IconButton,
Switch,
Text,
} from '@invoke-ai/ui-library';
import { NUMPY_RAND_MAX } from 'app/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { openFloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
import { openIntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
import { FloatRangeGenerator } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldNumberCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import {
fieldNumberCollectionGeneratorCommitted,
fieldNumberCollectionGeneratorStateChanged,
fieldNumberCollectionGeneratorToggled,
fieldNumberCollectionLockLinearViewToggled,
fieldNumberCollectionValueChanged,
} from 'features/nodes/store/nodesSlice';
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { isNil } from 'lodash-es';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type {
FloatRangeStartStepCountGenerator,
IntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { PiLockSimpleFill, PiLockSimpleOpenFill, PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
@@ -47,7 +59,7 @@ export const NumberFieldCollectionInputComponent = memo(
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
) => {
const { nodeId, field, fieldTemplate } = props;
const { nodeId, field, fieldTemplate, isLinearView } = props;
const store = useAppStore();
const { t } = useTranslation();
@@ -77,17 +89,6 @@ export const NumberFieldCollectionInputComponent = memo(
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
const onOpenGenerator = useCallback(() => {
const onSave = (values: number[]) => {
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: values }));
};
if (isIntegerField) {
openIntegerRangeGeneratorModal(onSave);
} else {
openFloatRangeGeneratorModal(onSave);
}
}, [field.name, isIntegerField, nodeId, store]);
const min = useMemo(() => {
let min = -NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.minimum)) {
@@ -124,6 +125,32 @@ export const NumberFieldCollectionInputComponent = memo(
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
const toggleGenerator = useCallback(() => {
store.dispatch(fieldNumberCollectionGeneratorToggled({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const onChangeGenerator = useCallback(
(generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator) => {
store.dispatch(fieldNumberCollectionGeneratorStateChanged({ nodeId, fieldName: field.name, generatorState }));
},
[field.name, nodeId, store]
);
const onCommitGenerator = useCallback(() => {
store.dispatch(fieldNumberCollectionGeneratorCommitted({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const onToggleLockLinearView = useCallback(() => {
store.dispatch(fieldNumberCollectionLockLinearViewToggled({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const valuesAsString = useMemo(() => {
const resolvedValue = resolveNumberFieldCollectionValue(field);
return resolvedValue ? resolvedValue.map((val) => round(val, 2)).join(', ') : '';
}, [field]);
const isLockedOnLinearView = !(field.lockLinearView && isLinearView);
return (
<Flex
className="nodrag"
@@ -140,17 +167,48 @@ export const NumberFieldCollectionInputComponent = memo(
flexDir="column"
gap={1}
>
<ButtonGroup isAttached={false} size="sm" w="full" gap={1}>
<Button onClick={onAddNumber} variant="ghost" w="50%">
{t('nodes.addItem')}
</Button>
<Button onClick={onOpenGenerator} variant="ghost" w="50%">
{t('nodes.generateValues')}
</Button>
</ButtonGroup>
{field.value && field.value.length > 0 && (
<Flex w="full" gap={2}>
{!field.generator && (
<Button onClick={onAddNumber} variant="ghost" flexGrow={1} size="sm">
{t('nodes.addValue')}
</Button>
)}
{field.generator && isLockedOnLinearView && (
<Button
tooltip={
<Flex p={1} flexDir="column">
<Text fontWeight="semibold">{t('nodes.generatedValues')}:</Text>
<Text fontFamily="monospace">{valuesAsString}</Text>
</Flex>
}
onClick={onCommitGenerator}
variant="ghost"
flexGrow={1}
size="sm"
>
{t('nodes.commitValues')}
</Button>
)}
{isLockedOnLinearView && (
<FormControl w="min-content" pe={isLinearView ? 2 : undefined}>
<FormLabel m={0}>{t('nodes.generator')}</FormLabel>
<Switch onChange={toggleGenerator} isChecked={Boolean(field.generator)} size="sm" />
</FormControl>
)}
{!isLinearView && (
<IconButton
onClick={onToggleLockLinearView}
tooltip={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
aria-label={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
icon={field.lockLinearView ? <PiLockSimpleFill /> : <PiLockSimpleOpenFill />}
variant="ghost"
size="sm"
/>
)}
</Flex>
{!field.generator && field.value && field.value.length > 0 && (
<>
<Divider />
{!(field.lockLinearView && isLinearView) && <Divider />}
<OverlayScrollbarsComponent
className="nowheel"
defer
@@ -176,6 +234,12 @@ export const NumberFieldCollectionInputComponent = memo(
</OverlayScrollbarsComponent>
</>
)}
{field.generator && field.generator.type === 'float-range-generator-start-step-count' && (
<>
{!(field.lockLinearView && isLinearView) && <Divider />}
<FloatRangeGenerator state={field.generator} onChange={onChangeGenerator} />
</>
)}
</Flex>
);
}

View File

@@ -4,4 +4,5 @@ export type FieldComponentProps<V extends FieldInputInstance, T extends FieldInp
nodeId: string;
field: V;
fieldTemplate: T;
isLinearView: boolean;
};

View File

@@ -46,7 +46,7 @@ const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
</Flex>
</Tooltip>
</Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={true} />
</Flex>
);
};

View File

@@ -63,13 +63,13 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
if (validateNumberFieldCollectionValue(field, template).length > 0) {
return true;
}
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
if (validateNumberFieldCollectionValue(field, template).length > 0) {
return true;
}
}

View File

@@ -1,8 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@@ -10,19 +11,38 @@ export const useFieldOriginalValue = (nodeId: string, fieldName: string) => {
const dispatch = useAppDispatch();
const selectOriginalExposedFieldValues = useMemo(
() =>
createSelector(
selectWorkflowSlice,
(workflow) =>
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)?.value
createMemoizedSelector(selectWorkflowSlice, (workflow) =>
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)
),
[nodeId, fieldName]
);
const originalValue = useAppSelector(selectOriginalExposedFieldValues);
const value = useFieldValue(nodeId, fieldName);
const isValueChanged = useMemo(() => !isEqual(value, originalValue), [value, originalValue]);
const exposedField = useAppSelector(selectOriginalExposedFieldValues);
const field = useFieldInputInstance(nodeId, fieldName);
const isValueChanged = useMemo(() => {
if (!field) {
// Field is not found, so it is not changed
return false;
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputInstance(exposedField?.field)) {
return !isEqual(field.generator, exposedField.field.generator);
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputInstance(exposedField?.field)) {
return !isEqual(field.generator, exposedField.field.generator);
}
return !isEqual(field.value, exposedField?.field.value);
}, [field, exposedField]);
const onReset = useCallback(() => {
dispatch(fieldValueReset({ nodeId, fieldName, value: originalValue }));
}, [dispatch, fieldName, nodeId, originalValue]);
if (!exposedField) {
return;
}
const { value } = exposedField.field;
const generator =
isIntegerFieldCollectionInputInstance(exposedField.field) ||
isFloatFieldCollectionInputInstance(exposedField.field)
? exposedField.field.generator
: undefined;
dispatch(fieldValueReset({ nodeId, fieldName, value, generator }));
}, [dispatch, fieldName, nodeId, exposedField]);
return { originalValue, isValueChanged, onReset };
return { originalValue: exposedField, isValueChanged, onReset };
};

View File

@@ -36,6 +36,8 @@ import type {
VAEModelFieldValue,
} from 'features/nodes/types/field';
import {
isFloatFieldCollectionInputInstance,
isIntegerFieldCollectionInputInstance,
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
@@ -66,6 +68,16 @@ import {
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type {
FloatRangeStartStepCountGenerator,
IntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import {
floatRangeStartStepCountGenerator,
getDefaultFloatRangeStartStepCountGenerator,
getDefaultIntegerRangeStartStepCountGenerator,
integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom, computed } from 'nanostores';
@@ -83,11 +95,22 @@ const initialNodesState: NodesState = {
edges: [],
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T;
}>;
type FieldValueAction<T extends FieldValue, U = unknown> = PayloadAction<
{
nodeId: string;
fieldName: string;
value: T;
} & U
>;
const selectField = (state: NodesState, nodeId: string, fieldName: string) => {
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
return node.data?.inputs[fieldName];
};
const fieldValueReducer = <T extends FieldValue>(
state: NodesState,
@@ -95,17 +118,24 @@ const fieldValueReducer = <T extends FieldValue>(
schema: z.ZodTypeAny
) => {
const { nodeId, fieldName, value } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
const input = node.data?.inputs[fieldName];
const field = selectField(state, nodeId, fieldName);
const result = schema.safeParse(value);
if (!input || nodeIndex < 0 || !result.success) {
if (!field || !result.success) {
return;
}
input.value = result.data;
field.value = result.data;
// Special handling if the field value is being reset
if (result.data === undefined) {
if (isFloatFieldCollectionInputInstance(field)) {
if (field.lockLinearView && field.generator) {
field.generator = getDefaultFloatRangeStartStepCountGenerator();
}
} else if (isIntegerFieldCollectionInputInstance(field)) {
if (field.lockLinearView && field.generator) {
field.generator = getDefaultIntegerRangeStartStepCountGenerator();
}
}
}
};
export const nodesSlice = createSlice({
@@ -310,8 +340,31 @@ export const nodesSlice = createSlice({
}
node.data.notes = notes;
},
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
fieldValueReducer(state, action, zStatefulFieldValue);
fieldValueReset: (
state,
action: FieldValueAction<
StatefulFieldValue,
{ generator?: IntegerRangeStartStepCountGenerator | FloatRangeStartStepCountGenerator }
>
) => {
const { nodeId, fieldName, value, generator } = action.payload;
const field = selectField(state, nodeId, fieldName);
const result = zStatefulFieldValue.safeParse(value);
if (!field || !result.success) {
return;
}
field.value = result.data;
if (isFloatFieldCollectionInputInstance(field) && generator?.type === 'float-range-generator-start-step-count') {
field.generator = generator;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
generator?.type === 'integer-range-generator-start-step-count'
) {
field.generator = generator;
}
},
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
fieldValueReducer(state, action, zStringFieldValue);
@@ -325,6 +378,85 @@ export const nodesSlice = createSlice({
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
},
fieldNumberCollectionGeneratorToggled: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (isFloatFieldCollectionInputInstance(field)) {
field.generator = field.generator ? undefined : getDefaultFloatRangeStartStepCountGenerator();
} else if (isIntegerFieldCollectionInputInstance(field)) {
field.generator = field.generator ? undefined : getDefaultIntegerRangeStartStepCountGenerator();
} else {
// This should never happen
}
},
fieldNumberCollectionGeneratorStateChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldName: string;
generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator;
}>
) => {
const { nodeId, fieldName, generatorState } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (
isFloatFieldCollectionInputInstance(field) &&
generatorState.type === 'float-range-generator-start-step-count'
) {
field.generator = generatorState;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
generatorState.type === 'integer-range-generator-start-step-count'
) {
field.generator = generatorState;
} else {
// This should never happen
}
},
fieldNumberCollectionGeneratorCommitted: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (
isFloatFieldCollectionInputInstance(field) &&
field.generator &&
field.generator.type === 'float-range-generator-start-step-count'
) {
field.value = floatRangeStartStepCountGenerator(field.generator);
field.generator = undefined;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
field.generator &&
field.generator.type === 'integer-range-generator-start-step-count'
) {
field.value = integerRangeStartStepCountGenerator(field.generator);
field.generator = undefined;
} else {
// This should never happen
}
},
fieldNumberCollectionLockLinearViewToggled: (
state,
action: PayloadAction<{ nodeId: string; fieldName: string }>
) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (!isFloatFieldCollectionInputInstance(field) && !isIntegerFieldCollectionInputInstance(field)) {
return;
}
field.lockLinearView = !field.lockLinearView;
},
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
fieldValueReducer(state, action, zBooleanFieldValue);
},
@@ -447,6 +579,10 @@ export const {
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldNumberCollectionGeneratorToggled,
fieldNumberCollectionGeneratorStateChanged,
fieldNumberCollectionGeneratorCommitted,
fieldNumberCollectionLockLinearViewToggled,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,

View File

@@ -1,8 +1,8 @@
import type {
FieldIdentifier,
FieldInputInstance,
FieldInputTemplate,
FieldOutputTemplate,
StatefulFieldValue,
} from 'features/nodes/types/field';
import type {
AnyNode,
@@ -31,15 +31,15 @@ export type NodesState = {
};
export type WorkflowMode = 'edit' | 'view';
export type FieldIdentifierWithValue = FieldIdentifier & {
value: StatefulFieldValue;
export type FieldIdentifierWithInstance = FieldIdentifier & {
field: FieldInputInstance;
};
export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
_version: 1;
_version: 2;
isTouched: boolean;
mode: WorkflowMode;
originalExposedFieldValues: FieldIdentifierWithValue[];
originalExposedFieldValues: FieldIdentifierWithInstance[];
searchTerm: string;
orderBy?: WorkflowRecordOrderBy;
orderDirection: SQLiteDirection;

View File

@@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import type {
FieldIdentifierWithValue,
FieldIdentifierWithInstance,
WorkflowMode,
WorkflowsState as WorkflowState,
} from 'features/nodes/store/types';
@@ -31,7 +31,7 @@ const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
};
const initialWorkflowState: WorkflowState = {
_version: 1,
_version: 2,
isTouched: false,
mode: 'view',
originalExposedFieldValues: [],
@@ -62,7 +62,7 @@ export const workflowSlice = createSlice({
const { id, isOpen } = action.payload;
state.categorySections[id] = isOpen;
},
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithValue>) => {
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithInstance>) => {
state.exposedFields = uniqBy(
state.exposedFields.concat(omit(action.payload, 'value')),
(field) => `${field.nodeId}-${field.fieldName}`
@@ -128,25 +128,25 @@ export const workflowSlice = createSlice({
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges: _edges, ...workflowExtra } = action.payload;
const originalExposedFieldValues: FieldIdentifierWithValue[] = [];
const originalExposedFieldValues: FieldIdentifierWithInstance[] = [];
workflowExtra.exposedFields.forEach((field) => {
const node = nodes.find((n) => n.id === field.nodeId);
workflowExtra.exposedFields.forEach(({ nodeId, fieldName }) => {
const node = nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const input = node.data.inputs[field.fieldName];
const field = node.data.inputs[fieldName];
if (!input) {
if (!field) {
return;
}
const originalExposedFieldValue = {
nodeId: field.nodeId,
fieldName: field.fieldName,
value: input.value,
nodeId,
fieldName,
field,
};
originalExposedFieldValues.push(originalExposedFieldValue);
});
@@ -243,6 +243,9 @@ const migrateWorkflowState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
return deepClone(initialWorkflowState);
}
return state;
};

View File

@@ -1,3 +1,7 @@
import {
zFloatRangeStartStepCountGenerator,
zIntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
import { z } from 'zod';
@@ -282,6 +286,8 @@ export const isIntegerFieldInputTemplate = buildTypeGuard(zIntegerFieldInputTemp
export const zIntegerFieldCollectionValue = z.array(zIntegerFieldValue).optional();
const zIntegerFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zIntegerFieldCollectionValue,
generator: zIntegerRangeStartStepCountGenerator.optional(),
lockLinearView: z.boolean().default(false),
});
const zIntegerFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
@@ -343,9 +349,12 @@ export const isFloatFieldInputTemplate = buildTypeGuard(zFloatFieldInputTemplate
// #endregion
// #region FloatField Collection
export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zFloatFieldCollectionValue,
generator: zFloatRangeStartStepCountGenerator.optional(),
lockLinearView: z.boolean().default(false),
});
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
@@ -373,7 +382,6 @@ const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatCollectionFieldType,
});
export type FloatFieldCollectionValue = z.infer<typeof zFloatFieldCollectionValue>;
export type FloatFieldCollectionInputInstance = z.infer<typeof zFloatFieldCollectionInputInstance>;
export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollectionInputTemplate>;
export const isFloatFieldCollectionInputInstance = buildTypeGuard(zFloatFieldCollectionInputInstance);

View File

@@ -1,13 +1,17 @@
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
FloatFieldCollectionValue,
ImageFieldCollectionInputTemplate,
ImageFieldCollectionValue,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
IntegerFieldCollectionValue,
StringFieldCollectionInputTemplate,
StringFieldCollectionValue,
} from 'features/nodes/types/field';
import {
floatRangeStartStepCountGenerator,
integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { t } from 'i18next';
export const validateImageFieldCollectionValue = (
@@ -67,12 +71,31 @@ export const validateStringFieldCollectionValue = (
return reasons;
};
export const resolveNumberFieldCollectionValue = (
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance
): number[] | undefined => {
if (field.generator?.type === 'float-range-generator-start-step-count') {
return floatRangeStartStepCountGenerator(field.generator);
} else if (field.generator?.type === 'integer-range-generator-start-step-count') {
return integerRangeStartStepCountGenerator(field.generator);
} else {
return field.value;
}
};
export const validateNumberFieldCollectionValue = (
value: NonNullable<IntegerFieldCollectionValue> | NonNullable<FloatFieldCollectionValue>,
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance,
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
const value = resolveNumberFieldCollectionValue(field);
if (value === undefined) {
reasons.push(t('parameters.invoke.collectionEmpty'));
return reasons;
}
const count = value.length;
// Image collections may have min or max items to validate

View File

@@ -0,0 +1,29 @@
import { z } from 'zod';
export const zFloatRangeStartStepCountGenerator = z.object({
type: z.literal('float-range-generator-start-step-count').default('float-range-generator-start-step-count'),
start: z.number().default(0),
step: z.number().default(1),
count: z.number().int().default(10),
});
export type FloatRangeStartStepCountGenerator = z.infer<typeof zFloatRangeStartStepCountGenerator>;
export const floatRangeStartStepCountGenerator = (generator: FloatRangeStartStepCountGenerator): number[] => {
const { start, step, count } = generator;
return Array.from({ length: count }, (_, i) => start + i * step);
};
export const getDefaultFloatRangeStartStepCountGenerator = (): FloatRangeStartStepCountGenerator =>
zFloatRangeStartStepCountGenerator.parse({});
export const zIntegerRangeStartStepCountGenerator = z.object({
type: z.literal('integer-range-generator-start-step-count').default('integer-range-generator-start-step-count'),
start: z.number().int().default(0),
step: z.number().int().default(1),
count: z.number().int().default(10),
});
export type IntegerRangeStartStepCountGenerator = z.infer<typeof zIntegerRangeStartStepCountGenerator>;
export const integerRangeStartStepCountGenerator = (generator: IntegerRangeStartStepCountGenerator): number[] => {
const { start, step, count } = generator;
return Array.from({ length: count }, (_, i) => start + i * step);
};
export const getDefaultIntegerRangeStartStepCountGenerator = (): IntegerRangeStartStepCountGenerator =>
zIntegerRangeStartStepCountGenerator.parse({});

View File

@@ -1,5 +1,7 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { negate, omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
@@ -25,7 +27,11 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
inputsAccumulator[name] = input.value;
if (isFloatFieldCollectionInputInstance(input) || isIntegerFieldCollectionInputInstance(input)) {
inputsAccumulator[name] = resolveNumberFieldCollectionValue(input);
} else {
inputsAccumulator[name] = input.value;
}
return inputsAccumulator;
},

View File

@@ -30,6 +30,7 @@ import {
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import {
resolveNumberFieldCollectionValue,
validateImageFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
@@ -176,14 +177,14 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
isIntegerFieldCollectionInputInstance(field) &&
isIntegerFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isFloatFieldCollectionInputInstance(field) &&
isFloatFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
}
});
@@ -555,10 +556,10 @@ const getBatchCollectionSize = (batchNode: InvocationNode) => {
return batchNode.data.inputs.strings.value?.length ?? 0;
} else if (batchNode.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
return batchNode.data.inputs.floats.value?.length ?? 0;
return resolveNumberFieldCollectionValue(batchNode.data.inputs.floats)?.length ?? 0;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
return batchNode.data.inputs.integers.value?.length ?? 0;
return resolveNumberFieldCollectionValue(batchNode.data.inputs.integers)?.length ?? 0;
}
return 0;
};