mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
refactor(ui): persistent workflow field value generators
Previously, workflow generators existed on a layer above the workflow and were ephemeral. Generators could be run and the result saved to the workflow. There was no way to have the generator and its settings to be an inherent part of the workflow. When you refresh the page or load a workflow, the generator settings are reset. For example, a number collection field's value is a list of numbers. When you use a range generator for that field, the generated list of numbers is written to the workflow. When you refresh the page or load the workflow later, all you have is the list of numbers. This change makes generators a part of the workflow itself. In other words, the a field's generator settings are persisted to the workflow alongside the field, and eligible fields can be thought of as having a generator _as_ their state. For example, consider a number collection. If the field has a generator enabled, the generator settings are stored in the workflow directly, in that field's state. When we need to access the field's value, if it has a generator, we run the generator. If there is no generator, we get the directly-entered value. This enables an important use-case, where the workflow editor can set up a good baseline generator and save it to the workflow. Then the workflow user loads the workflow, and they just see the generator settings, importantly with the default values set by the editor. They never need to see a big list of values. - Add generator persistence to number collection field values. - Update all logic that references field values to use "resolved" field values if the field is a number collection field. This includes validation logic. - Rework the generator UI. Generators are now part of each field, not a separate modal. You can enable the generator, reset its values, commit them and then edit them. Or, disable the generator to manually edit the values. - Support locking the linear view mode. If the workflow editor locks the field, the linear view will be slimmed down, showing only the generator fields. - Rework how the "reset to default value" functionality works with exposed fields to also work with generators. **Unfortunately, this did require some changes to redux state that I cannot easily handle in a redux state migration. As a result, on the first run after updating Invoke, their workflow editor state will be erased.**
This commit is contained in:
@@ -855,7 +855,13 @@
|
||||
},
|
||||
"nodes": {
|
||||
"noBatchGroup": "no group",
|
||||
"generator": "Generator",
|
||||
"generatedValues": "Generated Values",
|
||||
"commitValues": "Commit Values",
|
||||
"addValue": "Add Value",
|
||||
"addNode": "Add Node",
|
||||
"lockLinearView": "Lock Linear View",
|
||||
"unlockLinearView": "Unlock Linear View",
|
||||
"addNodeToolTip": "Add Node (Shift+A, Space)",
|
||||
"addLinearView": "Add to Linear View",
|
||||
"animatedEdges": "Animated Edges",
|
||||
@@ -994,11 +1000,7 @@
|
||||
"imageAccessError": "Unable to find image {{image_name}}, resetting to default",
|
||||
"boardAccessError": "Unable to find board {{board_id}}, resetting to default",
|
||||
"modelAccessError": "Unable to find model {{key}}, resetting to default",
|
||||
"saveToGallery": "Save To Gallery",
|
||||
"addItem": "Add Item",
|
||||
"generateValues": "Generate Values",
|
||||
"floatRangeGenerator": "Float Range Generator",
|
||||
"integerRangeGenerator": "Integer Range Generator"
|
||||
"saveToGallery": "Save To Gallery"
|
||||
},
|
||||
"parameters": {
|
||||
"aspect": "Aspect",
|
||||
|
||||
@@ -22,8 +22,6 @@ import { DynamicPromptsModal } from 'features/dynamicPrompts/components/DynamicP
|
||||
import DeleteBoardModal from 'features/gallery/components/Boards/DeleteBoardModal';
|
||||
import { ImageContextMenu } from 'features/gallery/components/ImageContextMenu/ImageContextMenu';
|
||||
import { useStarterModelsToast } from 'features/modelManagerV2/hooks/useStarterModelsToast';
|
||||
import { FloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
|
||||
import { IntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
|
||||
import { ShareWorkflowModal } from 'features/nodes/components/sidePanel/WorkflowListMenu/ShareWorkflowModal';
|
||||
import { ClearQueueConfirmationsAlertDialog } from 'features/queue/components/ClearQueueConfirmationAlertDialog';
|
||||
import { DeleteStylePresetDialog } from 'features/stylePresets/components/DeleteStylePresetDialog';
|
||||
@@ -112,8 +110,6 @@ const App = ({ config = DEFAULT_CONFIG, studioInitAction }: Props) => {
|
||||
<ImageContextMenu />
|
||||
<FullscreenDropzone />
|
||||
<VideosModal />
|
||||
<FloatRangeGeneratorModal />
|
||||
<IntegerRangeGeneratorModal />
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
|
||||
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
@@ -140,10 +141,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
const resolvedValue = resolveNumberFieldCollectionValue(integers);
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, integers.value);
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, integers.value);
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,10 +165,11 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
const resolvedValue = resolveNumberFieldCollectionValue(floats);
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, floats.value);
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, floats.value);
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,8 +166,10 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
reducer: rememberedRootReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
serializableCheck: import.meta.env.MODE === 'development',
|
||||
immutableCheck: import.meta.env.MODE === 'development',
|
||||
serializableCheck: false,
|
||||
immutableCheck: false,
|
||||
// serializableCheck: import.meta.env.MODE === 'development',
|
||||
// immutableCheck: import.meta.env.MODE === 'development',
|
||||
})
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
|
||||
@@ -1,105 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
CompositeNumberInput,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
IconButton,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { round } from 'lodash-es';
|
||||
import { atom } from 'nanostores';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
type FloatRangeGeneratorModalState = {
|
||||
isOpen: boolean;
|
||||
onSave: (values: number[]) => void;
|
||||
};
|
||||
|
||||
const $floatRangeGeneratorModal = atom<FloatRangeGeneratorModalState>({
|
||||
isOpen: false,
|
||||
onSave: () => {},
|
||||
});
|
||||
|
||||
export const openFloatRangeGeneratorModal = (onSave: (values: number[]) => void) => {
|
||||
$floatRangeGeneratorModal.set({ ...$floatRangeGeneratorModal.get(), isOpen: true, onSave });
|
||||
};
|
||||
|
||||
const onClose = () => {
|
||||
$floatRangeGeneratorModal.set({ ...$floatRangeGeneratorModal.get(), isOpen: false });
|
||||
};
|
||||
|
||||
export const FloatRangeGeneratorModal = memo(() => {
|
||||
const { isOpen, onSave } = useStore($floatRangeGeneratorModal);
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [start, setStart] = useState(0);
|
||||
const [step, setStep] = useState(1);
|
||||
const [count, setCount] = useState(1);
|
||||
|
||||
const values = useMemo(() => Array.from({ length: count }, (_, i) => start + i * step), [start, step, count]);
|
||||
|
||||
const onReset = useCallback(() => {
|
||||
setStart(0);
|
||||
setStep(1);
|
||||
setCount(1);
|
||||
}, []);
|
||||
|
||||
const onClickSave = useCallback(() => {
|
||||
onSave(values);
|
||||
onClose();
|
||||
}, [onSave, values]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={isOpen} onClose={onClose} isCentered>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>{t('nodes.floatRangeGenerator')}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody minH={200}>
|
||||
<Flex gap={4} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput value={start} onChange={setStart} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={count} onChange={setCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.step')}</FormLabel>
|
||||
<CompositeNumberInput value={step} onChange={setStep} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
|
||||
</Flex>
|
||||
<Flex w="full" h="auto" flexDir="column" gap={2} pt={4}>
|
||||
<FormLabel>{t('common.values')}</FormLabel>
|
||||
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base">
|
||||
<Text fontFamily="monospace" fontSize="md" color="base.300">
|
||||
{values.map((val) => round(val, 2)).join(', ')}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
<ModalFooter gap={2}>
|
||||
<Button onClick={onClose} variant="ghost">
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button onClick={onClickSave}>{t('common.save')}</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
|
||||
FloatRangeGeneratorModal.displayName = 'FloatRangeGeneratorModal';
|
||||
@@ -1,103 +0,0 @@
|
||||
import {
|
||||
Button,
|
||||
CompositeNumberInput,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
IconButton,
|
||||
Modal,
|
||||
ModalBody,
|
||||
ModalCloseButton,
|
||||
ModalContent,
|
||||
ModalFooter,
|
||||
ModalHeader,
|
||||
ModalOverlay,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { atom } from 'nanostores';
|
||||
import { memo, useCallback, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
type IntegerRangeGeneratorModalState = {
|
||||
isOpen: boolean;
|
||||
onSave: (values: number[]) => void;
|
||||
};
|
||||
|
||||
const $integerRangeGeneratorModal = atom<IntegerRangeGeneratorModalState>({
|
||||
isOpen: false,
|
||||
onSave: () => {},
|
||||
});
|
||||
|
||||
export const openIntegerRangeGeneratorModal = (onSave: (values: number[]) => void) => {
|
||||
$integerRangeGeneratorModal.set({ ...$integerRangeGeneratorModal.get(), isOpen: true, onSave });
|
||||
};
|
||||
|
||||
const onClose = () => {
|
||||
$integerRangeGeneratorModal.set({ ...$integerRangeGeneratorModal.get(), isOpen: false });
|
||||
};
|
||||
|
||||
export const IntegerRangeGeneratorModal = memo(() => {
|
||||
const { isOpen, onSave } = useStore($integerRangeGeneratorModal);
|
||||
const { t } = useTranslation();
|
||||
const [start, setStart] = useState(0);
|
||||
const [step, setStep] = useState(1);
|
||||
const [count, setCount] = useState(1);
|
||||
|
||||
const values = useMemo(() => Array.from({ length: count }, (_, i) => start + i * step), [start, step, count]);
|
||||
|
||||
const onReset = useCallback(() => {
|
||||
setStart(0);
|
||||
setStep(1);
|
||||
setCount(1);
|
||||
}, []);
|
||||
|
||||
const onClickSave = useCallback(() => {
|
||||
onSave(values);
|
||||
onClose();
|
||||
}, [onSave, values]);
|
||||
|
||||
return (
|
||||
<Modal isOpen={isOpen} onClose={onClose} isCentered>
|
||||
<ModalOverlay />
|
||||
<ModalContent>
|
||||
<ModalHeader>{t('nodes.integerRangeGenerator')}</ModalHeader>
|
||||
<ModalCloseButton />
|
||||
<ModalBody minH={200}>
|
||||
<Flex gap={4} alignItems="flex-end">
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput value={start} onChange={setStart} min={-Infinity} max={Infinity} step={1} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={count} onChange={setCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical">
|
||||
<FormLabel>{t('common.step')}</FormLabel>
|
||||
<CompositeNumberInput value={step} onChange={setStep} min={-Infinity} max={Infinity} step={1} />
|
||||
</FormControl>
|
||||
<IconButton aria-label="Reset" icon={<PiArrowCounterClockwiseBold />} onClick={onReset} variant="ghost" />
|
||||
</Flex>
|
||||
<Flex w="full" h="auto" flexDir="column" gap={2} pt={4}>
|
||||
<FormLabel>{t('common.values')}</FormLabel>
|
||||
<Flex w="full" h="full" p={2} borderWidth={1} borderRadius="base">
|
||||
<Text fontFamily="monospace" fontSize="md" color="base.300">
|
||||
{values.join(', ')}
|
||||
</Text>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</ModalBody>
|
||||
<ModalFooter gap={2}>
|
||||
<Button onClick={onClose} variant="ghost">
|
||||
{t('common.cancel')}
|
||||
</Button>
|
||||
<Button onClick={onClickSave}>{t('common.save')}</Button>
|
||||
</ModalFooter>
|
||||
</ModalContent>
|
||||
</Modal>
|
||||
);
|
||||
});
|
||||
|
||||
IntegerRangeGeneratorModal.displayName = 'IntegerRangeGeneratorModal';
|
||||
@@ -43,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
{fieldNames.connectionFields.map((fieldName, i) => (
|
||||
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
|
||||
<InvocationInputFieldCheck nodeId={nodeId} fieldName={fieldName}>
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
|
||||
</InvocationInputFieldCheck>
|
||||
</GridItem>
|
||||
))}
|
||||
@@ -59,7 +59,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
>
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
|
||||
</InvocationInputFieldCheck>
|
||||
))}
|
||||
{fieldNames.missingFields.map((fieldName) => (
|
||||
@@ -68,7 +68,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
>
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
|
||||
</InvocationInputFieldCheck>
|
||||
))}
|
||||
</Flex>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import {
|
||||
selectWorkflowSlice,
|
||||
workflowExposedFieldAdded,
|
||||
@@ -19,7 +19,7 @@ type Props = {
|
||||
const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const value = useFieldValue(nodeId, fieldName);
|
||||
const field = useFieldInputInstance(nodeId, fieldName);
|
||||
const selectIsExposed = useMemo(
|
||||
() =>
|
||||
createSelector(selectWorkflowSlice, (workflow) => {
|
||||
@@ -31,8 +31,11 @@ const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
|
||||
const isExposed = useAppSelector(selectIsExposed);
|
||||
|
||||
const handleExposeField = useCallback(() => {
|
||||
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, value }));
|
||||
}, [dispatch, fieldName, nodeId, value]);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, field }));
|
||||
}, [dispatch, field, fieldName, nodeId]);
|
||||
|
||||
const handleUnexposeField = useCallback(() => {
|
||||
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
|
||||
|
||||
@@ -14,9 +14,10 @@ import { InputFieldWrapper } from './InputFieldWrapper';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
isLinearView: boolean;
|
||||
}
|
||||
|
||||
const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const isInvalid = useFieldIsInvalid(nodeId, fieldName);
|
||||
@@ -69,12 +70,12 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
px={2}
|
||||
>
|
||||
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
<Flex gap={1}>
|
||||
<Flex gap={1} alignItems="center">
|
||||
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" isInvalid={isInvalid} withTooltip />
|
||||
{isHovered && <FieldResetToDefaultValueButton nodeId={nodeId} fieldName={fieldName} />}
|
||||
{isHovered && <FieldLinearViewToggle nodeId={nodeId} fieldName={fieldName} />}
|
||||
</Flex>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={isLinearView} />
|
||||
</Flex>
|
||||
</FormControl>
|
||||
|
||||
|
||||
@@ -99,109 +99,285 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
type InputFieldProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
isLinearView: boolean;
|
||||
};
|
||||
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps) => {
|
||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
|
||||
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<StringFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<StringFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
|
||||
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<BooleanFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<NumberFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<NumberFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<NumberFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <NumberFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<NumberFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
|
||||
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<EnumFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <ImageFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ImageFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
|
||||
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ImageFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
|
||||
return <BoardFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<BoardFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<MainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
|
||||
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ModelIdentifierFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<RefinerModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<VAEModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<T5EncoderModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<CLIPEmbedModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<CLIPLEmbedModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<CLIPGEmbedModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isControlLoRAModelFieldInputInstance(fieldInstance) && isControlLoRAModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ControlLoRAModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<FluxVAEModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<LoRAModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ControlNetModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<IPAdapterModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<T2IAdapterModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -213,28 +389,64 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
||||
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ColorFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<FluxMainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<SD3MainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<SDXLMainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
|
||||
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<SchedulerFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldTemplate) {
|
||||
|
||||
@@ -97,7 +97,11 @@ const LinearViewFieldInternal = ({ fieldIdentifier }: Props) => {
|
||||
icon={<PiTrashSimpleBold />}
|
||||
/>
|
||||
</Flex>
|
||||
<InputFieldRenderer nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName} />
|
||||
<InputFieldRenderer
|
||||
nodeId={fieldIdentifier.nodeId}
|
||||
fieldName={fieldIdentifier.fieldName}
|
||||
isLinearView={true}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<DndListDropIndicator dndState={dndListState} />
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
type FloatRangeStartStepCountGenerator,
|
||||
getDefaultFloatRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
type FloatRangeGeneratorProps = {
|
||||
state: FloatRangeStartStepCountGenerator;
|
||||
onChange: (state: FloatRangeStartStepCountGenerator) => void;
|
||||
};
|
||||
|
||||
export const FloatRangeGenerator = memo(({ state, onChange }: FloatRangeGeneratorProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeStart = useCallback(
|
||||
(start: number) => {
|
||||
onChange({ ...state, start });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeStep = useCallback(
|
||||
(step: number) => {
|
||||
onChange({ ...state, step });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const onReset = useCallback(() => {
|
||||
onChange(getDefaultFloatRangeStartStepCountGenerator());
|
||||
}, [onChange]);
|
||||
|
||||
return (
|
||||
<Flex gap={1} alignItems="flex-end" p={1}>
|
||||
<FormControl orientation="vertical" gap={1}>
|
||||
<FormLabel m={0}>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical" gap={1}>
|
||||
<FormLabel m={0}>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical" gap={1}>
|
||||
<FormLabel m={0}>{t('common.step')}</FormLabel>
|
||||
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<IconButton
|
||||
onClick={onReset}
|
||||
aria-label={t('common.reset')}
|
||||
icon={<PiArrowCounterClockwiseBold />}
|
||||
variant="ghost"
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
FloatRangeGenerator.displayName = 'FloatRangeGenerator';
|
||||
@@ -1,33 +1,45 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Button,
|
||||
ButtonGroup,
|
||||
CompositeNumberInput,
|
||||
Divider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Grid,
|
||||
GridItem,
|
||||
IconButton,
|
||||
Switch,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { openFloatRangeGeneratorModal } from 'features/nodes/components/FloatRangeGeneratorModal';
|
||||
import { openIntegerRangeGeneratorModal } from 'features/nodes/components/IntegerRangeGeneratorModal';
|
||||
import { FloatRangeGenerator } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import { fieldNumberCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import {
|
||||
fieldNumberCollectionGeneratorCommitted,
|
||||
fieldNumberCollectionGeneratorStateChanged,
|
||||
fieldNumberCollectionGeneratorToggled,
|
||||
fieldNumberCollectionLockLinearViewToggled,
|
||||
fieldNumberCollectionValueChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FloatFieldCollectionInputInstance,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
IntegerFieldCollectionInputInstance,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isNil } from 'lodash-es';
|
||||
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
|
||||
import type {
|
||||
FloatRangeStartStepCountGenerator,
|
||||
IntegerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { isNil, round } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { PiLockSimpleFill, PiLockSimpleOpenFill, PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
@@ -47,7 +59,7 @@ export const NumberFieldCollectionInputComponent = memo(
|
||||
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
|
||||
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate } = props;
|
||||
const { nodeId, field, fieldTemplate, isLinearView } = props;
|
||||
const store = useAppStore();
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -77,17 +89,6 @@ export const NumberFieldCollectionInputComponent = memo(
|
||||
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
}, [field.name, field.value, nodeId, store]);
|
||||
|
||||
const onOpenGenerator = useCallback(() => {
|
||||
const onSave = (values: number[]) => {
|
||||
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: values }));
|
||||
};
|
||||
if (isIntegerField) {
|
||||
openIntegerRangeGeneratorModal(onSave);
|
||||
} else {
|
||||
openFloatRangeGeneratorModal(onSave);
|
||||
}
|
||||
}, [field.name, isIntegerField, nodeId, store]);
|
||||
|
||||
const min = useMemo(() => {
|
||||
let min = -NUMPY_RAND_MAX;
|
||||
if (!isNil(fieldTemplate.minimum)) {
|
||||
@@ -124,6 +125,32 @@ export const NumberFieldCollectionInputComponent = memo(
|
||||
return fieldTemplate.multipleOf;
|
||||
}, [fieldTemplate.multipleOf, isIntegerField]);
|
||||
|
||||
const toggleGenerator = useCallback(() => {
|
||||
store.dispatch(fieldNumberCollectionGeneratorToggled({ nodeId, fieldName: field.name }));
|
||||
}, [field.name, nodeId, store]);
|
||||
|
||||
const onChangeGenerator = useCallback(
|
||||
(generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator) => {
|
||||
store.dispatch(fieldNumberCollectionGeneratorStateChanged({ nodeId, fieldName: field.name, generatorState }));
|
||||
},
|
||||
[field.name, nodeId, store]
|
||||
);
|
||||
|
||||
const onCommitGenerator = useCallback(() => {
|
||||
store.dispatch(fieldNumberCollectionGeneratorCommitted({ nodeId, fieldName: field.name }));
|
||||
}, [field.name, nodeId, store]);
|
||||
|
||||
const onToggleLockLinearView = useCallback(() => {
|
||||
store.dispatch(fieldNumberCollectionLockLinearViewToggled({ nodeId, fieldName: field.name }));
|
||||
}, [field.name, nodeId, store]);
|
||||
|
||||
const valuesAsString = useMemo(() => {
|
||||
const resolvedValue = resolveNumberFieldCollectionValue(field);
|
||||
return resolvedValue ? resolvedValue.map((val) => round(val, 2)).join(', ') : '';
|
||||
}, [field]);
|
||||
|
||||
const isLockedOnLinearView = !(field.lockLinearView && isLinearView);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
className="nodrag"
|
||||
@@ -140,17 +167,48 @@ export const NumberFieldCollectionInputComponent = memo(
|
||||
flexDir="column"
|
||||
gap={1}
|
||||
>
|
||||
<ButtonGroup isAttached={false} size="sm" w="full" gap={1}>
|
||||
<Button onClick={onAddNumber} variant="ghost" w="50%">
|
||||
{t('nodes.addItem')}
|
||||
</Button>
|
||||
<Button onClick={onOpenGenerator} variant="ghost" w="50%">
|
||||
{t('nodes.generateValues')}
|
||||
</Button>
|
||||
</ButtonGroup>
|
||||
{field.value && field.value.length > 0 && (
|
||||
<Flex w="full" gap={2}>
|
||||
{!field.generator && (
|
||||
<Button onClick={onAddNumber} variant="ghost" flexGrow={1} size="sm">
|
||||
{t('nodes.addValue')}
|
||||
</Button>
|
||||
)}
|
||||
{field.generator && isLockedOnLinearView && (
|
||||
<Button
|
||||
tooltip={
|
||||
<Flex p={1} flexDir="column">
|
||||
<Text fontWeight="semibold">{t('nodes.generatedValues')}:</Text>
|
||||
<Text fontFamily="monospace">{valuesAsString}</Text>
|
||||
</Flex>
|
||||
}
|
||||
onClick={onCommitGenerator}
|
||||
variant="ghost"
|
||||
flexGrow={1}
|
||||
size="sm"
|
||||
>
|
||||
{t('nodes.commitValues')}
|
||||
</Button>
|
||||
)}
|
||||
{isLockedOnLinearView && (
|
||||
<FormControl w="min-content" pe={isLinearView ? 2 : undefined}>
|
||||
<FormLabel m={0}>{t('nodes.generator')}</FormLabel>
|
||||
<Switch onChange={toggleGenerator} isChecked={Boolean(field.generator)} size="sm" />
|
||||
</FormControl>
|
||||
)}
|
||||
{!isLinearView && (
|
||||
<IconButton
|
||||
onClick={onToggleLockLinearView}
|
||||
tooltip={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
|
||||
aria-label={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
|
||||
icon={field.lockLinearView ? <PiLockSimpleFill /> : <PiLockSimpleOpenFill />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
{!field.generator && field.value && field.value.length > 0 && (
|
||||
<>
|
||||
<Divider />
|
||||
{!(field.lockLinearView && isLinearView) && <Divider />}
|
||||
<OverlayScrollbarsComponent
|
||||
className="nowheel"
|
||||
defer
|
||||
@@ -176,6 +234,12 @@ export const NumberFieldCollectionInputComponent = memo(
|
||||
</OverlayScrollbarsComponent>
|
||||
</>
|
||||
)}
|
||||
{field.generator && field.generator.type === 'float-range-generator-start-step-count' && (
|
||||
<>
|
||||
{!(field.lockLinearView && isLinearView) && <Divider />}
|
||||
<FloatRangeGenerator state={field.generator} onChange={onChangeGenerator} />
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,4 +4,5 @@ export type FieldComponentProps<V extends FieldInputInstance, T extends FieldInp
|
||||
nodeId: string;
|
||||
field: V;
|
||||
fieldTemplate: T;
|
||||
isLinearView: boolean;
|
||||
};
|
||||
|
||||
@@ -46,7 +46,7 @@ const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={true} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -63,13 +63,13 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
|
||||
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
|
||||
if (validateNumberFieldCollectionValue(field, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
|
||||
if (validateNumberFieldCollectionValue(field.value, template).length > 0) {
|
||||
if (validateNumberFieldCollectionValue(field, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
|
||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
|
||||
@@ -10,19 +11,38 @@ export const useFieldOriginalValue = (nodeId: string, fieldName: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectOriginalExposedFieldValues = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectWorkflowSlice,
|
||||
(workflow) =>
|
||||
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)?.value
|
||||
createMemoizedSelector(selectWorkflowSlice, (workflow) =>
|
||||
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)
|
||||
),
|
||||
[nodeId, fieldName]
|
||||
);
|
||||
const originalValue = useAppSelector(selectOriginalExposedFieldValues);
|
||||
const value = useFieldValue(nodeId, fieldName);
|
||||
const isValueChanged = useMemo(() => !isEqual(value, originalValue), [value, originalValue]);
|
||||
const exposedField = useAppSelector(selectOriginalExposedFieldValues);
|
||||
const field = useFieldInputInstance(nodeId, fieldName);
|
||||
const isValueChanged = useMemo(() => {
|
||||
if (!field) {
|
||||
// Field is not found, so it is not changed
|
||||
return false;
|
||||
}
|
||||
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputInstance(exposedField?.field)) {
|
||||
return !isEqual(field.generator, exposedField.field.generator);
|
||||
}
|
||||
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputInstance(exposedField?.field)) {
|
||||
return !isEqual(field.generator, exposedField.field.generator);
|
||||
}
|
||||
return !isEqual(field.value, exposedField?.field.value);
|
||||
}, [field, exposedField]);
|
||||
const onReset = useCallback(() => {
|
||||
dispatch(fieldValueReset({ nodeId, fieldName, value: originalValue }));
|
||||
}, [dispatch, fieldName, nodeId, originalValue]);
|
||||
if (!exposedField) {
|
||||
return;
|
||||
}
|
||||
const { value } = exposedField.field;
|
||||
const generator =
|
||||
isIntegerFieldCollectionInputInstance(exposedField.field) ||
|
||||
isFloatFieldCollectionInputInstance(exposedField.field)
|
||||
? exposedField.field.generator
|
||||
: undefined;
|
||||
dispatch(fieldValueReset({ nodeId, fieldName, value, generator }));
|
||||
}, [dispatch, fieldName, nodeId, exposedField]);
|
||||
|
||||
return { originalValue, isValueChanged, onReset };
|
||||
return { originalValue: exposedField, isValueChanged, onReset };
|
||||
};
|
||||
|
||||
@@ -36,6 +36,8 @@ import type {
|
||||
VAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
zBoardFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
@@ -66,6 +68,16 @@ import {
|
||||
zT5EncoderModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
FloatRangeStartStepCountGenerator,
|
||||
IntegerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import {
|
||||
floatRangeStartStepCountGenerator,
|
||||
getDefaultFloatRangeStartStepCountGenerator,
|
||||
getDefaultIntegerRangeStartStepCountGenerator,
|
||||
integerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { atom, computed } from 'nanostores';
|
||||
@@ -83,11 +95,22 @@ const initialNodesState: NodesState = {
|
||||
edges: [],
|
||||
};
|
||||
|
||||
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value: T;
|
||||
}>;
|
||||
type FieldValueAction<T extends FieldValue, U = unknown> = PayloadAction<
|
||||
{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value: T;
|
||||
} & U
|
||||
>;
|
||||
|
||||
const selectField = (state: NodesState, nodeId: string, fieldName: string) => {
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node.data?.inputs[fieldName];
|
||||
};
|
||||
|
||||
const fieldValueReducer = <T extends FieldValue>(
|
||||
state: NodesState,
|
||||
@@ -95,17 +118,24 @@ const fieldValueReducer = <T extends FieldValue>(
|
||||
schema: z.ZodTypeAny
|
||||
) => {
|
||||
const { nodeId, fieldName, value } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const input = node.data?.inputs[fieldName];
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
const result = schema.safeParse(value);
|
||||
if (!input || nodeIndex < 0 || !result.success) {
|
||||
if (!field || !result.success) {
|
||||
return;
|
||||
}
|
||||
input.value = result.data;
|
||||
field.value = result.data;
|
||||
// Special handling if the field value is being reset
|
||||
if (result.data === undefined) {
|
||||
if (isFloatFieldCollectionInputInstance(field)) {
|
||||
if (field.lockLinearView && field.generator) {
|
||||
field.generator = getDefaultFloatRangeStartStepCountGenerator();
|
||||
}
|
||||
} else if (isIntegerFieldCollectionInputInstance(field)) {
|
||||
if (field.lockLinearView && field.generator) {
|
||||
field.generator = getDefaultIntegerRangeStartStepCountGenerator();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const nodesSlice = createSlice({
|
||||
@@ -310,8 +340,31 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = notes;
|
||||
},
|
||||
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStatefulFieldValue);
|
||||
fieldValueReset: (
|
||||
state,
|
||||
action: FieldValueAction<
|
||||
StatefulFieldValue,
|
||||
{ generator?: IntegerRangeStartStepCountGenerator | FloatRangeStartStepCountGenerator }
|
||||
>
|
||||
) => {
|
||||
const { nodeId, fieldName, value, generator } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
const result = zStatefulFieldValue.safeParse(value);
|
||||
|
||||
if (!field || !result.success) {
|
||||
return;
|
||||
}
|
||||
|
||||
field.value = result.data;
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(field) && generator?.type === 'float-range-generator-start-step-count') {
|
||||
field.generator = generator;
|
||||
} else if (
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
generator?.type === 'integer-range-generator-start-step-count'
|
||||
) {
|
||||
field.generator = generator;
|
||||
}
|
||||
},
|
||||
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStringFieldValue);
|
||||
@@ -325,6 +378,85 @@ export const nodesSlice = createSlice({
|
||||
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
|
||||
},
|
||||
fieldNumberCollectionGeneratorToggled: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
|
||||
const { nodeId, fieldName } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
if (isFloatFieldCollectionInputInstance(field)) {
|
||||
field.generator = field.generator ? undefined : getDefaultFloatRangeStartStepCountGenerator();
|
||||
} else if (isIntegerFieldCollectionInputInstance(field)) {
|
||||
field.generator = field.generator ? undefined : getDefaultIntegerRangeStartStepCountGenerator();
|
||||
} else {
|
||||
// This should never happen
|
||||
}
|
||||
},
|
||||
fieldNumberCollectionGeneratorStateChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator;
|
||||
}>
|
||||
) => {
|
||||
const { nodeId, fieldName, generatorState } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
isFloatFieldCollectionInputInstance(field) &&
|
||||
generatorState.type === 'float-range-generator-start-step-count'
|
||||
) {
|
||||
field.generator = generatorState;
|
||||
} else if (
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
generatorState.type === 'integer-range-generator-start-step-count'
|
||||
) {
|
||||
field.generator = generatorState;
|
||||
} else {
|
||||
// This should never happen
|
||||
}
|
||||
},
|
||||
fieldNumberCollectionGeneratorCommitted: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
|
||||
const { nodeId, fieldName } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
isFloatFieldCollectionInputInstance(field) &&
|
||||
field.generator &&
|
||||
field.generator.type === 'float-range-generator-start-step-count'
|
||||
) {
|
||||
field.value = floatRangeStartStepCountGenerator(field.generator);
|
||||
field.generator = undefined;
|
||||
} else if (
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
field.generator &&
|
||||
field.generator.type === 'integer-range-generator-start-step-count'
|
||||
) {
|
||||
field.value = integerRangeStartStepCountGenerator(field.generator);
|
||||
field.generator = undefined;
|
||||
} else {
|
||||
// This should never happen
|
||||
}
|
||||
},
|
||||
fieldNumberCollectionLockLinearViewToggled: (
|
||||
state,
|
||||
action: PayloadAction<{ nodeId: string; fieldName: string }>
|
||||
) => {
|
||||
const { nodeId, fieldName } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
if (!isFloatFieldCollectionInputInstance(field) && !isIntegerFieldCollectionInputInstance(field)) {
|
||||
return;
|
||||
}
|
||||
field.lockLinearView = !field.lockLinearView;
|
||||
},
|
||||
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
|
||||
fieldValueReducer(state, action, zBooleanFieldValue);
|
||||
},
|
||||
@@ -447,6 +579,10 @@ export const {
|
||||
fieldMainModelValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
fieldNumberCollectionValueChanged,
|
||||
fieldNumberCollectionGeneratorToggled,
|
||||
fieldNumberCollectionGeneratorStateChanged,
|
||||
fieldNumberCollectionGeneratorCommitted,
|
||||
fieldNumberCollectionLockLinearViewToggled,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import type {
|
||||
FieldIdentifier,
|
||||
FieldInputInstance,
|
||||
FieldInputTemplate,
|
||||
FieldOutputTemplate,
|
||||
StatefulFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
AnyNode,
|
||||
@@ -31,15 +31,15 @@ export type NodesState = {
|
||||
};
|
||||
|
||||
export type WorkflowMode = 'edit' | 'view';
|
||||
export type FieldIdentifierWithValue = FieldIdentifier & {
|
||||
value: StatefulFieldValue;
|
||||
export type FieldIdentifierWithInstance = FieldIdentifier & {
|
||||
field: FieldInputInstance;
|
||||
};
|
||||
|
||||
export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
|
||||
_version: 1;
|
||||
_version: 2;
|
||||
isTouched: boolean;
|
||||
mode: WorkflowMode;
|
||||
originalExposedFieldValues: FieldIdentifierWithValue[];
|
||||
originalExposedFieldValues: FieldIdentifierWithInstance[];
|
||||
searchTerm: string;
|
||||
orderBy?: WorkflowRecordOrderBy;
|
||||
orderDirection: SQLiteDirection;
|
||||
|
||||
@@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FieldIdentifierWithValue,
|
||||
FieldIdentifierWithInstance,
|
||||
WorkflowMode,
|
||||
WorkflowsState as WorkflowState,
|
||||
} from 'features/nodes/store/types';
|
||||
@@ -31,7 +31,7 @@ const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
|
||||
};
|
||||
|
||||
const initialWorkflowState: WorkflowState = {
|
||||
_version: 1,
|
||||
_version: 2,
|
||||
isTouched: false,
|
||||
mode: 'view',
|
||||
originalExposedFieldValues: [],
|
||||
@@ -62,7 +62,7 @@ export const workflowSlice = createSlice({
|
||||
const { id, isOpen } = action.payload;
|
||||
state.categorySections[id] = isOpen;
|
||||
},
|
||||
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithValue>) => {
|
||||
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithInstance>) => {
|
||||
state.exposedFields = uniqBy(
|
||||
state.exposedFields.concat(omit(action.payload, 'value')),
|
||||
(field) => `${field.nodeId}-${field.fieldName}`
|
||||
@@ -128,25 +128,25 @@ export const workflowSlice = createSlice({
|
||||
builder.addCase(workflowLoaded, (state, action) => {
|
||||
const { nodes, edges: _edges, ...workflowExtra } = action.payload;
|
||||
|
||||
const originalExposedFieldValues: FieldIdentifierWithValue[] = [];
|
||||
const originalExposedFieldValues: FieldIdentifierWithInstance[] = [];
|
||||
|
||||
workflowExtra.exposedFields.forEach((field) => {
|
||||
const node = nodes.find((n) => n.id === field.nodeId);
|
||||
workflowExtra.exposedFields.forEach(({ nodeId, fieldName }) => {
|
||||
const node = nodes.find((n) => n.id === nodeId);
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const input = node.data.inputs[field.fieldName];
|
||||
const field = node.data.inputs[fieldName];
|
||||
|
||||
if (!input) {
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
|
||||
const originalExposedFieldValue = {
|
||||
nodeId: field.nodeId,
|
||||
fieldName: field.fieldName,
|
||||
value: input.value,
|
||||
nodeId,
|
||||
fieldName,
|
||||
field,
|
||||
};
|
||||
originalExposedFieldValues.push(originalExposedFieldValue);
|
||||
});
|
||||
@@ -243,6 +243,9 @@ const migrateWorkflowState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
return deepClone(initialWorkflowState);
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
import {
|
||||
zFloatRangeStartStepCountGenerator,
|
||||
zIntegerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -282,6 +286,8 @@ export const isIntegerFieldInputTemplate = buildTypeGuard(zIntegerFieldInputTemp
|
||||
export const zIntegerFieldCollectionValue = z.array(zIntegerFieldValue).optional();
|
||||
const zIntegerFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIntegerFieldCollectionValue,
|
||||
generator: zIntegerRangeStartStepCountGenerator.optional(),
|
||||
lockLinearView: z.boolean().default(false),
|
||||
});
|
||||
const zIntegerFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
.extend({
|
||||
@@ -343,9 +349,12 @@ export const isFloatFieldInputTemplate = buildTypeGuard(zFloatFieldInputTemplate
|
||||
// #endregion
|
||||
|
||||
// #region FloatField Collection
|
||||
|
||||
export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
|
||||
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFloatFieldCollectionValue,
|
||||
generator: zFloatRangeStartStepCountGenerator.optional(),
|
||||
lockLinearView: z.boolean().default(false),
|
||||
});
|
||||
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
.extend({
|
||||
@@ -373,7 +382,6 @@ const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zFloatCollectionFieldType,
|
||||
});
|
||||
export type FloatFieldCollectionValue = z.infer<typeof zFloatFieldCollectionValue>;
|
||||
export type FloatFieldCollectionInputInstance = z.infer<typeof zFloatFieldCollectionInputInstance>;
|
||||
export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollectionInputTemplate>;
|
||||
export const isFloatFieldCollectionInputInstance = buildTypeGuard(zFloatFieldCollectionInputInstance);
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import type {
|
||||
FloatFieldCollectionInputInstance,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
FloatFieldCollectionValue,
|
||||
ImageFieldCollectionInputTemplate,
|
||||
ImageFieldCollectionValue,
|
||||
IntegerFieldCollectionInputInstance,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
IntegerFieldCollectionValue,
|
||||
StringFieldCollectionInputTemplate,
|
||||
StringFieldCollectionValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
floatRangeStartStepCountGenerator,
|
||||
integerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const validateImageFieldCollectionValue = (
|
||||
@@ -67,12 +71,31 @@ export const validateStringFieldCollectionValue = (
|
||||
return reasons;
|
||||
};
|
||||
|
||||
export const resolveNumberFieldCollectionValue = (
|
||||
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance
|
||||
): number[] | undefined => {
|
||||
if (field.generator?.type === 'float-range-generator-start-step-count') {
|
||||
return floatRangeStartStepCountGenerator(field.generator);
|
||||
} else if (field.generator?.type === 'integer-range-generator-start-step-count') {
|
||||
return integerRangeStartStepCountGenerator(field.generator);
|
||||
} else {
|
||||
return field.value;
|
||||
}
|
||||
};
|
||||
|
||||
export const validateNumberFieldCollectionValue = (
|
||||
value: NonNullable<IntegerFieldCollectionValue> | NonNullable<FloatFieldCollectionValue>,
|
||||
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance,
|
||||
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
|
||||
): string[] => {
|
||||
const reasons: string[] = [];
|
||||
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
|
||||
const value = resolveNumberFieldCollectionValue(field);
|
||||
|
||||
if (value === undefined) {
|
||||
reasons.push(t('parameters.invoke.collectionEmpty'));
|
||||
return reasons;
|
||||
}
|
||||
|
||||
const count = value.length;
|
||||
|
||||
// Image collections may have min or max items to validate
|
||||
|
||||
29
invokeai/frontend/web/src/features/nodes/types/generators.ts
Normal file
29
invokeai/frontend/web/src/features/nodes/types/generators.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
export const zFloatRangeStartStepCountGenerator = z.object({
|
||||
type: z.literal('float-range-generator-start-step-count').default('float-range-generator-start-step-count'),
|
||||
start: z.number().default(0),
|
||||
step: z.number().default(1),
|
||||
count: z.number().int().default(10),
|
||||
});
|
||||
export type FloatRangeStartStepCountGenerator = z.infer<typeof zFloatRangeStartStepCountGenerator>;
|
||||
export const floatRangeStartStepCountGenerator = (generator: FloatRangeStartStepCountGenerator): number[] => {
|
||||
const { start, step, count } = generator;
|
||||
return Array.from({ length: count }, (_, i) => start + i * step);
|
||||
};
|
||||
export const getDefaultFloatRangeStartStepCountGenerator = (): FloatRangeStartStepCountGenerator =>
|
||||
zFloatRangeStartStepCountGenerator.parse({});
|
||||
|
||||
export const zIntegerRangeStartStepCountGenerator = z.object({
|
||||
type: z.literal('integer-range-generator-start-step-count').default('integer-range-generator-start-step-count'),
|
||||
start: z.number().int().default(0),
|
||||
step: z.number().int().default(1),
|
||||
count: z.number().int().default(10),
|
||||
});
|
||||
export type IntegerRangeStartStepCountGenerator = z.infer<typeof zIntegerRangeStartStepCountGenerator>;
|
||||
export const integerRangeStartStepCountGenerator = (generator: IntegerRangeStartStepCountGenerator): number[] => {
|
||||
const { start, step, count } = generator;
|
||||
return Array.from({ length: count }, (_, i) => start + i * step);
|
||||
};
|
||||
export const getDefaultIntegerRangeStartStepCountGenerator = (): IntegerRangeStartStepCountGenerator =>
|
||||
zIntegerRangeStartStepCountGenerator.parse({});
|
||||
@@ -1,5 +1,7 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { negate, omit, reduce } from 'lodash-es';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
@@ -25,7 +27,11 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
const transformedInputs = reduce(
|
||||
inputs,
|
||||
(inputsAccumulator, input, name) => {
|
||||
inputsAccumulator[name] = input.value;
|
||||
if (isFloatFieldCollectionInputInstance(input) || isIntegerFieldCollectionInputInstance(input)) {
|
||||
inputsAccumulator[name] = resolveNumberFieldCollectionValue(input);
|
||||
} else {
|
||||
inputsAccumulator[name] = input.value;
|
||||
}
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
|
||||
@@ -30,6 +30,7 @@ import {
|
||||
isStringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
resolveNumberFieldCollectionValue,
|
||||
validateImageFieldCollectionValue,
|
||||
validateNumberFieldCollectionValue,
|
||||
validateStringFieldCollectionValue,
|
||||
@@ -176,14 +177,14 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
isIntegerFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
|
||||
const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
} else if (
|
||||
field.value &&
|
||||
isFloatFieldCollectionInputInstance(field) &&
|
||||
isFloatFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateNumberFieldCollectionValue(field.value, fieldTemplate);
|
||||
const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
}
|
||||
});
|
||||
@@ -555,10 +556,10 @@ const getBatchCollectionSize = (batchNode: InvocationNode) => {
|
||||
return batchNode.data.inputs.strings.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'float_batch') {
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
return batchNode.data.inputs.floats.value?.length ?? 0;
|
||||
return resolveNumberFieldCollectionValue(batchNode.data.inputs.floats)?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
|
||||
return batchNode.data.inputs.integers.value?.length ?? 0;
|
||||
return resolveNumberFieldCollectionValue(batchNode.data.inputs.integers)?.length ?? 0;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user