Compare commits

...

65 Commits

Author SHA1 Message Date
psychedelicious
6df3e9f960 refactor(ui): persistent workflow field value generators
Previously, workflow generators existed on a layer above the workflow and were ephemeral. Generators could be run and the result saved to the workflow.

There was no way to have the generator and its settings to be an inherent part of the workflow. When you refresh the page or load a workflow, the generator settings are reset.

For example, a number collection field's value is a list of numbers. When you use a range generator for that field, the generated list of numbers is written to the workflow. When you refresh the page or load the workflow later, all you have is the list of numbers.

This change makes generators a part of the workflow itself. In other words, the a field's generator settings are persisted to the workflow alongside the field, and eligible fields can be thought of as having a generator _as_ their state.

For example, consider a number collection. If the field has a generator enabled, the generator settings are stored in the workflow directly, in that field's state. When we need to access the field's value, if it has a generator, we run the generator. If there is no generator, we get the directly-entered value.

This enables an important use-case, where the workflow editor can set up a good baseline generator and save it to the workflow.

Then the workflow user loads the workflow, and they just see the generator settings, importantly with the default values set by the editor. They never need to see a big list of values.

- Add generator persistence to number collection field values.
- Update all logic that references field values to use "resolved" field values if the field is a number collection field. This includes validation logic.
- Rework the generator UI. Generators are now part of each field, not a separate modal. You can enable the generator, reset its values, commit them and then edit them. Or, disable the generator to manually edit the values.
- Support locking the linear view mode. If the workflow editor locks the field, the linear view will be slimmed down, showing only the generator fields.
- Rework how the "reset to default value" functionality works with exposed fields to also work with generators. **Unfortunately, this did require some changes to redux state that I cannot easily handle in a redux state migration. As a result, on the first run after updating Invoke, their workflow editor state will be erased.**
2025-01-16 10:22:26 +11:00
psychedelicious
42d59b961e tidy(nodes): code dedupe for batch node init errors 2025-01-16 10:22:26 +11:00
psychedelicious
e8eac3d259 tidy(nodes): move batch nodes to own file 2025-01-16 10:22:26 +11:00
psychedelicious
74a4197398 chore(ui): knip 2025-01-16 09:48:10 +11:00
psychedelicious
ba1701d734 tweak(ui): error verbiage for collection size mismatch 2025-01-16 09:48:10 +11:00
psychedelicious
10a77d22ef fix(ui): invoke tooltip for invalid/empty batches 2025-01-16 09:48:10 +11:00
psychedelicious
54dbc16cc9 chore(ui): lint 2025-01-16 09:48:09 +11:00
psychedelicious
a95cc979a9 fix(ui): unclosed JSX tag 2025-01-16 09:48:09 +11:00
psychedelicious
b23abba8a6 feat(ui): validate all batch nodes have connection 2025-01-16 09:48:09 +11:00
psychedelicious
d1ded55d8d feat(ui): show batch group in node title 2025-01-16 09:48:09 +11:00
psychedelicious
d81cbd0a14 fix(ui): handle batch group ids of "None" correctly 2025-01-16 09:48:09 +11:00
psychedelicious
fa21f0887d tweak(ui): enum field selects have size="sm" 2025-01-16 09:48:09 +11:00
psychedelicious
4ecb4e8929 chore(ui): typegen 2025-01-16 09:48:09 +11:00
psychedelicious
ede13f7882 feat(nodes): add title for batch_group_id field 2025-01-16 09:48:09 +11:00
psychedelicious
77472a2f0c tweak(ui): node editor layout padding 2025-01-16 09:48:09 +11:00
psychedelicious
029c1fb8d9 chore(ui): typegen 2025-01-16 09:48:09 +11:00
psychedelicious
100b151f84 feat(nodes): batch_group_id is a literal of options 2025-01-16 09:48:09 +11:00
psychedelicious
a7326e3ad4 feat(ui): rename "link_id" -> "batch_group_id" 2025-01-16 09:48:09 +11:00
psychedelicious
775bb276b2 chore(ui): typegen 2025-01-16 09:48:09 +11:00
psychedelicious
977c2668e8 feat(nodes): rename "link_id" -> "batch_group_id" 2025-01-16 09:48:09 +11:00
psychedelicious
515ff485fc feat(ui): add zipped batch collection size validation 2025-01-16 09:48:09 +11:00
psychedelicious
87cbc8ad45 fix(ui): allow batch nodes without link id (i.e. product batch nodes) to have mismatched collection sizes 2025-01-16 09:48:09 +11:00
psychedelicious
2c250c29e8 feat(ui): support zipped batch nodes 2025-01-16 09:48:09 +11:00
psychedelicious
9db69782e9 chore(ui): typegen 2025-01-16 09:48:09 +11:00
psychedelicious
2499cf0c52 feat(nodes): add link_id field to batch nodes
This is used to link batch nodes into zipped batch data collections.
2025-01-16 09:48:09 +11:00
psychedelicious
febc9615fc chore(ui): typegen 2025-01-16 09:48:09 +11:00
psychedelicious
b6aae16471 chore(ui): lint 2025-01-16 09:46:46 +11:00
psychedelicious
f27250a4a7 chore(ui): typegen 2025-01-16 09:46:46 +11:00
psychedelicious
0b95319dfa tweak(ui): number collection styling 2025-01-16 09:46:46 +11:00
psychedelicious
388580efb3 feat(ui): string collection batch items are input not textarea 2025-01-16 09:46:45 +11:00
psychedelicious
9ac30bd2a5 fix(ui): translation key 2025-01-16 09:46:45 +11:00
psychedelicious
c3d2eb5426 feat(ui): add number range generators 2025-01-16 09:46:45 +11:00
psychedelicious
789eb1fff5 Revert "feat(ui): rough out number generators for number collection fields"
This reverts commit 41cc6f1f96bca2a51727f21bd727ca48eab669bc.
2025-01-16 09:46:45 +11:00
psychedelicious
fb5af7a4b7 Revert "feat(ui): number collection generator supports floats"
This reverts commit 9da3339b513de9575ffbf6ce880b3097217b199d.
2025-01-16 09:46:45 +11:00
psychedelicious
eea29863a0 Revert "feat(ui): more batch generator stuff"
This reverts commit 111a29c7b4fc6b5062a0a37ce704a6508ff58dd8.
2025-01-16 09:46:45 +11:00
psychedelicious
fab0af4d77 feat(ui): more batch generator stuff 2025-01-16 09:46:45 +11:00
psychedelicious
420c1d2874 tidy(ui): abstract out batch detection logic 2025-01-16 09:46:45 +11:00
psychedelicious
d980a87e25 feat(nodes): add default value for batch nodes 2025-01-16 09:46:45 +11:00
psychedelicious
d02a8a9b62 feat(ui): number collection generator supports floats 2025-01-16 09:46:45 +11:00
psychedelicious
084228c162 fix(ui): do not set number collection field to undefined when removing last item 2025-01-16 09:46:45 +11:00
psychedelicious
366ac86cbe fix(ui): filter out batch nodes when checking readiness on workflows tab 2025-01-16 09:46:45 +11:00
psychedelicious
55423ad1d6 perf(ui): memoize selector in workflows 2025-01-16 09:46:45 +11:00
psychedelicious
11f17e3ea0 feat(ui): rough out number generators for number collection fields 2025-01-16 09:46:45 +11:00
psychedelicious
2d145871d9 fix(nodes): allow batch datum items to mix ints and floats
Unfortunately we cannot do strict floats or ints.

The batch data models don't specify the value types, it instead relies on pydantic parsing. JSON doesn't differentiate between float and int, so a float `1.0` gets parsed as `1` in python.

As a result, we _must_ accept mixed floats and ints for BatchDatum.items.

Tests and validation updated to handle this.

Maybe we should update the BatchDatum model to have a `type` field? Then we could parse as float or int, depending on the inputs...
2025-01-16 09:46:45 +11:00
psychedelicious
aace8366d6 fix(ui): float batch data creation 2025-01-16 09:46:45 +11:00
psychedelicious
77555615bc chore(ui): lint 2025-01-16 09:46:45 +11:00
psychedelicious
4d9b35e8bd tidy(ui): use zod typeguard builder util for fields 2025-01-16 09:46:45 +11:00
psychedelicious
3cecc25d6c chore(ui): typegen 2025-01-16 09:46:44 +11:00
psychedelicious
5a610dd00c feat(ui): validate number item multipleOf 2025-01-16 09:46:24 +11:00
psychedelicious
2d8443cc21 feat(ui): validate string item lengths 2025-01-16 09:46:24 +11:00
psychedelicious
1ed70bb21e feat(ui): support float batches 2025-01-16 09:46:24 +11:00
psychedelicious
e7a61c86f1 refactor(ui): abstract out helper to add batch data 2025-01-16 09:46:24 +11:00
psychedelicious
7d1f38560b fix(ui): typo 2025-01-16 09:46:24 +11:00
psychedelicious
2fcca151e7 refactor(ui): abstract out field validators 2025-01-16 09:46:24 +11:00
psychedelicious
c83caed552 feat(ui): add template validation for integer collection items 2025-01-16 09:46:23 +11:00
psychedelicious
a1687fafdd feat(ui): add template validation for string collection items 2025-01-16 09:46:23 +11:00
psychedelicious
41a36f2701 feat(nodes): add float batch node 2025-01-16 09:46:23 +11:00
psychedelicious
8807248a1c feat(ui): support integer batches 2025-01-16 09:46:23 +11:00
psychedelicious
14a605c1e1 feat(nodes): add integer batch node 2025-01-16 09:46:23 +11:00
psychedelicious
edc39581ba feat(ui): support string batches 2025-01-16 09:46:23 +11:00
psychedelicious
de21cf1383 refactor(ui): streamline image field collection input logic, support multiple images w/ same name in collection 2025-01-16 09:46:23 +11:00
psychedelicious
035508e2ee tweak(ui): image field collection input component styling 2025-01-16 09:46:23 +11:00
psychedelicious
1b739b4f86 docs(ui): improved comments for image batch node special handling 2025-01-16 09:46:23 +11:00
psychedelicious
32f65937af feat(nodes): add string batch node 2025-01-16 09:46:23 +11:00
psychedelicious
6cb87a86c8 fix(ui): typo in error message for image collection fields 2025-01-16 09:46:23 +11:00
44 changed files with 2404 additions and 413 deletions

1
.nvmrc Normal file
View File

@@ -0,0 +1 @@
v22.12.0

View File

@@ -0,0 +1,118 @@
from typing import Literal
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
Classification,
invocation,
)
from invokeai.app.invocations.fields import (
ImageField,
Input,
InputField,
)
from invokeai.app.invocations.primitives import FloatOutput, ImageOutput, IntegerOutput, StringOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
BATCH_GROUP_IDS = Literal[
"None",
"Group 1",
"Group 2",
"Group 3",
"Group 4",
"Group 5",
]
class NotExecutableNodeError(Exception):
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
super().__init__(message)
pass
class BaseBatchInvocation(BaseInvocation):
batch_group_id: BATCH_GROUP_IDS = InputField(
default="None",
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
input=Input.Direct,
title="Batch Group",
)
def __init__(self):
raise NotExecutableNodeError()
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(
default=[], min_length=1, description="The images to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotExecutableNodeError()
@invocation(
"string_batch",
title="String Batch",
tags=["primitives", "string", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class StringBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
strings: list[str] = InputField(
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> StringOutput:
raise NotExecutableNodeError()
@invocation(
"integer_batch",
title="Integer Batch",
tags=["primitives", "integer", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class IntegerBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
integers: list[int] = InputField(
default=[], min_length=1, description="The integers to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> IntegerOutput:
raise NotExecutableNodeError()
@invocation(
"float_batch",
title="Float Batch",
tags=["primitives", "float", "number", "batch", "special"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class FloatBatchInvocation(BaseBatchInvocation):
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
floats: list[float] = InputField(
default=[], min_length=1, description="The floats to batch over", input=Input.Direct
)
def invoke(self, context: InvocationContext) -> FloatOutput:
raise NotExecutableNodeError()

View File

@@ -7,7 +7,6 @@ import torch
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
@@ -539,23 +538,3 @@ class BoundingBoxInvocation(BaseInvocation):
# endregion
@invocation(
"image_batch",
title="Image Batch",
tags=["primitives", "image", "batch", "internal"],
category="primitives",
version="1.0.0",
classification=Classification.Special,
)
class ImageBatchInvocation(BaseInvocation):
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
def __init__(self):
raise NotImplementedError("This class should never be executed or instantiated directly.")
def invoke(self, context: InvocationContext) -> ImageOutput:
raise NotImplementedError("This class should never be executed or instantiated directly.")

View File

@@ -108,8 +108,16 @@ class Batch(BaseModel):
return v
for batch_data_list in v:
for datum in batch_data_list:
if not datum.items:
continue
# Special handling for numbers - they can be mixed
# TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
if all(isinstance(item, (int, float)) for item in datum.items):
continue
# Get the type of the first item in the list
first_item_type = type(datum.items[0]) if datum.items else None
first_item_type = type(datum.items[0])
for item in datum.items:
if type(item) is not first_item_type:
raise BatchItemsTypeError("All items in a batch must have the same type")

View File

@@ -177,7 +177,11 @@
"none": "None",
"new": "New",
"generating": "Generating",
"warnings": "Warnings"
"warnings": "Warnings",
"start": "Start",
"count": "Count",
"step": "Step",
"values": "Values"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -850,7 +854,14 @@
"defaultVAE": "Default VAE"
},
"nodes": {
"noBatchGroup": "no group",
"generator": "Generator",
"generatedValues": "Generated Values",
"commitValues": "Commit Values",
"addValue": "Add Value",
"addNode": "Add Node",
"lockLinearView": "Lock Linear View",
"unlockLinearView": "Unlock Linear View",
"addNodeToolTip": "Add Node (Shift+A, Space)",
"addLinearView": "Add to Linear View",
"animatedEdges": "Animated Edges",
@@ -1024,11 +1035,21 @@
"addingImagesTo": "Adding images to",
"invoke": "Invoke",
"missingFieldTemplate": "Missing field template",
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
"missingInputForField": "missing input",
"missingNodeTemplate": "Missing node template",
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
"collectionEmpty": "empty collection",
"invalidBatchConfiguration": "Invalid batch configuration",
"batchNodeNotConnected": "Batch node not connected: {{label}}",
"collectionTooFewItems": "too few items, minimum {{minItems}}",
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
"collectionStringTooLong": "too long, max {{maxLength}}",
"collectionStringTooShort": "too short, min {{minLength}}",
"collectionNumberGTMax": "{{value}} > {{maximum}} (inc max)",
"collectionNumberLTMin": "{{value}} < {{minimum}} (inc min)",
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)",
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)",
"collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}",
"batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}",
"noModelSelected": "No model selected",
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",

View File

@@ -2,10 +2,19 @@ import { logger } from 'app/logging/logger';
import { enqueueRequested } from 'app/store/actions';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import type { ImageField } from 'features/nodes/types/common';
import {
isFloatFieldCollectionInputInstance,
isImageFieldCollectionInputInstance,
isIntegerFieldCollectionInputInstance,
isStringFieldCollectionInputInstance,
} from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { groupBy } from 'lodash-es';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, BatchConfig } from 'services/api/types';
@@ -33,28 +42,140 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
const data: Batch['data'] = [];
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromImageBatch) {
const batchNodes = nodes.nodes.filter(isInvocationNode).filter(isBatchNode);
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
const addProductBatchDataCollectionItem = (
edges: InvocationNodeEdge[],
items?: ImageField[] | string[] | number[]
) => {
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const edge of edges) {
if (!edge.targetHandle) {
break;
}
batchDataCollectionItem.push({
productBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: images.value,
items,
});
}
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
if (productBatchDataCollectionItems.length > 0) {
data.push(productBatchDataCollectionItems);
}
};
// Then, we will create a batch data collection item for each group
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
const addZippedBatchDataCollectionItem = (
edges: InvocationNodeEdge[],
items?: ImageField[] | string[] | number[]
) => {
for (const edge of edges) {
if (!edge.targetHandle) {
break;
}
zippedBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items,
});
}
};
// Grab image batch nodes for special handling
const imageBatchNodes = batchNodes.filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
// Satisfy TS
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromImageBatch, images.value);
} else {
addProductBatchDataCollectionItem(edgesFromImageBatch, images.value);
}
}
// Grab string batch nodes for special handling
const stringBatchNodes = batchNodes.filter((node) => node.data.type === 'string_batch');
for (const node of stringBatchNodes) {
// Satisfy TS
const strings = node.data.inputs['strings'];
if (!isStringFieldCollectionInputInstance(strings)) {
log.warn({ nodeId: node.id }, 'String batch strings field is not a string collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, strings.value);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, strings.value);
}
}
// Grab integer batch nodes for special handling
const integerBatchNodes = batchNodes.filter((node) => node.data.type === 'integer_batch');
for (const node of integerBatchNodes) {
// Satisfy TS
const integers = node.data.inputs['integers'];
if (!isIntegerFieldCollectionInputInstance(integers)) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is not an integer collection');
break;
}
if (!integers.value) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is empty');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const resolvedValue = resolveNumberFieldCollectionValue(integers);
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}
// Grab float batch nodes for special handling
const floatBatchNodes = batchNodes.filter((node) => node.data.type === 'float_batch');
for (const node of floatBatchNodes) {
// Satisfy TS
const floats = node.data.inputs['floats'];
if (!isFloatFieldCollectionInputInstance(floats)) {
log.warn({ nodeId: node.id }, 'Float batch floats field is not a float collection');
break;
}
if (!floats.value) {
log.warn({ nodeId: node.id }, 'Float batch floats field is empty');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
const resolvedValue = resolveNumberFieldCollectionValue(floats);
if (batchGroupId !== 'None') {
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
} else {
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
}
}
// Finally, if this batch data collection item has any items, add it to the data array
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
data.push(zippedBatchDataCollectionItems);
}
}

View File

@@ -166,8 +166,10 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
reducer: rememberedRootReducer,
middleware: (getDefaultMiddleware) =>
getDefaultMiddleware({
serializableCheck: import.meta.env.MODE === 'development',
immutableCheck: import.meta.env.MODE === 'development',
serializableCheck: false,
immutableCheck: false,
// serializableCheck: import.meta.env.MODE === 'development',
// immutableCheck: import.meta.env.MODE === 'development',
})
.concat(api.middleware)
.concat(dynamicMiddlewares)

View File

@@ -1,3 +1,4 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { getPrefixedId } from 'features/controlLayers/konva/util';
import type {
@@ -9,7 +10,6 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
import type { BoardId } from 'features/gallery/store/types';
import {
addImagesToBoard,
addImagesToNodeImageFieldCollectionAction,
createNewCanvasEntityFromImage,
removeImagesFromBoard,
replaceCanvasEntityObjectsWithImage,
@@ -19,10 +19,14 @@ import {
setRegionalGuidanceReferenceImage,
setUpscaleInitialImage,
} from 'features/imageActions/actions';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import type { ImageDTO } from 'services/api/types';
import type { JsonObject } from 'type-fest';
const log = logger('dnd');
type RecordUnknown = Record<string | symbol, unknown>;
type DndData<
@@ -268,15 +272,27 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
}
const { fieldIdentifier } = targetData.payload;
const imageDTOs: ImageDTO[] = [];
if (singleImageDndSource.typeGuard(sourceData)) {
imageDTOs.push(sourceData.payload.imageDTO);
} else {
imageDTOs.push(...sourceData.payload.imageDTOs);
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
if (singleImageDndSource.typeGuard(sourceData)) {
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
} else {
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
}
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
},
};
//#endregion

View File

@@ -1,4 +1,3 @@
import { logger } from 'app/logging/logger';
import type { AppDispatch, RootState } from 'app/store/store';
import { deepClone } from 'common/util/deepClone';
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
@@ -31,19 +30,15 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
import type { BoardId } from 'features/gallery/store/types';
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
import type { FieldIdentifier } from 'features/nodes/types/field';
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
import { uniqBy } from 'lodash-es';
import { imagesApi } from 'services/api/endpoints/images';
import type { ImageDTO } from 'services/api/types';
import type { Equals } from 'tsafe';
import { assert } from 'tsafe';
const log = logger('system');
export const setGlobalReferenceImage = (arg: {
imageDTO: ImageDTO;
entityIdentifier: CanvasEntityIdentifier<'reference_image'>;
@@ -77,54 +72,6 @@ export const setNodeImageFieldImage = (arg: {
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
};
export const addImagesToNodeImageFieldCollectionAction = (arg: {
imageDTOs: ImageDTO[];
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
return;
}
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
const uniqueImages = uniqBy(images, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
imageName: string;
fieldIdentifier: FieldIdentifier;
dispatch: AppDispatch;
getState: () => RootState;
}) => {
const { imageName, fieldIdentifier, dispatch, getState } = arg;
const fieldInputInstance = selectFieldInputInstance(
selectNodesSlice(getState()),
fieldIdentifier.nodeId,
fieldIdentifier.fieldName
);
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
return;
}
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
};
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
const { imageDTO, dispatch } = arg;
dispatch(imageToCompareChanged(imageDTO));

View File

@@ -43,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
{fieldNames.connectionFields.map((fieldName, i) => (
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
<InvocationInputFieldCheck nodeId={nodeId} fieldName={fieldName}>
<InputField nodeId={nodeId} fieldName={fieldName} />
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
</InvocationInputFieldCheck>
</GridItem>
))}
@@ -59,7 +59,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} />
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
</InvocationInputFieldCheck>
))}
{fieldNames.missingFields.map((fieldName) => (
@@ -68,7 +68,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
nodeId={nodeId}
fieldName={fieldName}
>
<InputField nodeId={nodeId} fieldName={fieldName} />
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
</InvocationInputFieldCheck>
))}
</Flex>

View File

@@ -1,7 +1,7 @@
import { IconButton } from '@invoke-ai/ui-library';
import { createSelector } from '@reduxjs/toolkit';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import {
selectWorkflowSlice,
workflowExposedFieldAdded,
@@ -19,7 +19,7 @@ type Props = {
const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
const dispatch = useAppDispatch();
const { t } = useTranslation();
const value = useFieldValue(nodeId, fieldName);
const field = useFieldInputInstance(nodeId, fieldName);
const selectIsExposed = useMemo(
() =>
createSelector(selectWorkflowSlice, (workflow) => {
@@ -31,8 +31,11 @@ const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
const isExposed = useAppSelector(selectIsExposed);
const handleExposeField = useCallback(() => {
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, value }));
}, [dispatch, fieldName, nodeId, value]);
if (!field) {
return;
}
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, field }));
}, [dispatch, field, fieldName, nodeId]);
const handleUnexposeField = useCallback(() => {
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));

View File

@@ -14,9 +14,10 @@ import { InputFieldWrapper } from './InputFieldWrapper';
interface Props {
nodeId: string;
fieldName: string;
isLinearView: boolean;
}
const InputField = ({ nodeId, fieldName }: Props) => {
const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
const [isHovered, setIsHovered] = useState(false);
const isInvalid = useFieldIsInvalid(nodeId, fieldName);
@@ -69,12 +70,12 @@ const InputField = ({ nodeId, fieldName }: Props) => {
px={2}
>
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
<Flex gap={1}>
<Flex gap={1} alignItems="center">
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" isInvalid={isInvalid} withTooltip />
{isHovered && <FieldResetToDefaultValueButton nodeId={nodeId} fieldName={fieldName} />}
{isHovered && <FieldLinearViewToggle nodeId={nodeId} fieldName={fieldName} />}
</Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={isLinearView} />
</Flex>
</FormControl>

View File

@@ -1,5 +1,7 @@
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import {
@@ -21,6 +23,8 @@ import {
isControlNetModelFieldInputTemplate,
isEnumFieldInputInstance,
isEnumFieldInputTemplate,
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isFloatFieldInputInstance,
isFloatFieldInputTemplate,
isFluxMainModelFieldInputInstance,
@@ -31,6 +35,8 @@ import {
isImageFieldCollectionInputTemplate,
isImageFieldInputInstance,
isImageFieldInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isIntegerFieldInputInstance,
isIntegerFieldInputTemplate,
isIPAdapterModelFieldInputInstance,
@@ -51,6 +57,8 @@ import {
isSDXLRefinerModelFieldInputTemplate,
isSpandrelImageToImageModelFieldInputInstance,
isSpandrelImageToImageModelFieldInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
isStringFieldInputInstance,
isStringFieldInputTemplate,
isT2IAdapterModelFieldInputInstance,
@@ -91,96 +99,285 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
type InputFieldProps = {
nodeId: string;
fieldName: string;
isLinearView: boolean;
};
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps) => {
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
return (
<StringFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<StringFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<BooleanFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
) {
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
return (
<NumberFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
return (
<NumberFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
return (
<NumberFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
return (
<NumberFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<EnumFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
return <ImageFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ImageFieldCollectionInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ImageFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
return <BoardFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<BoardFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ModelIdentifierFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<RefinerModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<VAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<T5EncoderModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<CLIPEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<CLIPLEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<CLIPGEmbedModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isControlLoRAModelFieldInputInstance(fieldInstance) && isControlLoRAModelFieldInputTemplate(fieldTemplate)) {
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ControlLoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<FluxVAEModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<LoRAModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ControlNetModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<IPAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<T2IAdapterModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (
@@ -192,28 +389,64 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<ColorFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<FluxMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<SD3MainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<SDXLMainModelFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
return (
<SchedulerFieldInputComponent
nodeId={nodeId}
field={fieldInstance}
fieldTemplate={fieldTemplate}
isLinearView={isLinearView}
/>
);
}
if (fieldTemplate) {

View File

@@ -1,6 +1,6 @@
import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectInvocationNode, selectNodesSlice } from 'features/nodes/store/selectors';
@@ -18,7 +18,7 @@ export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }:
const templates = useStore($templates);
const selector = useMemo(
() =>
createSelector(selectNodesSlice, (nodesSlice) => {
createMemoizedSelector(selectNodesSlice, (nodesSlice) => {
const node = selectInvocationNode(nodesSlice, nodeId);
const instance = node.data.inputs[fieldName];
const template = templates[node.data.type];

View File

@@ -97,7 +97,11 @@ const LinearViewFieldInternal = ({ fieldIdentifier }: Props) => {
icon={<PiTrashSimpleBold />}
/>
</Flex>
<InputFieldRenderer nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName} />
<InputFieldRenderer
nodeId={fieldIdentifier.nodeId}
fieldName={fieldIdentifier.fieldName}
isLinearView={true}
/>
</Flex>
</Flex>
<DndListDropIndicator dndState={dndListState} />

View File

@@ -26,7 +26,7 @@ const EnumFieldInputComponent = (props: FieldComponentProps<EnumFieldInputInstan
);
return (
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value}>
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value} size="sm">
{fieldTemplate.options.map((option) => (
<option key={option} value={option}>
{fieldTemplate.ui_choice_labels ? fieldTemplate.ui_choice_labels[option] : option}

View File

@@ -0,0 +1,65 @@
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
import {
type FloatRangeStartStepCountGenerator,
getDefaultFloatRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
type FloatRangeGeneratorProps = {
state: FloatRangeStartStepCountGenerator;
onChange: (state: FloatRangeStartStepCountGenerator) => void;
};
export const FloatRangeGenerator = memo(({ state, onChange }: FloatRangeGeneratorProps) => {
const { t } = useTranslation();
const onChangeStart = useCallback(
(start: number) => {
onChange({ ...state, start });
},
[onChange, state]
);
const onChangeStep = useCallback(
(step: number) => {
onChange({ ...state, step });
},
[onChange, state]
);
const onChangeCount = useCallback(
(count: number) => {
onChange({ ...state, count });
},
[onChange, state]
);
const onReset = useCallback(() => {
onChange(getDefaultFloatRangeStartStepCountGenerator());
}, [onChange]);
return (
<Flex gap={1} alignItems="flex-end" p={1}>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.start')}</FormLabel>
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.count')}</FormLabel>
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
</FormControl>
<FormControl orientation="vertical" gap={1}>
<FormLabel m={0}>{t('common.step')}</FormLabel>
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
</FormControl>
<IconButton
onClick={onReset}
aria-label={t('common.reset')}
icon={<PiArrowCounterClockwiseBold />}
variant="ghost"
/>
</Flex>
);
});
FloatRangeGenerator.displayName = 'FloatRangeGenerator';

View File

@@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
import { DndImage } from 'features/dnd/DndImage';
import { DndImageIcon } from 'features/dnd/DndImageIcon';
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type { ImageField } from 'features/nodes/types/common';
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
@@ -61,15 +61,12 @@ export const ImageFieldCollectionInputComponent = memo(
);
const onRemoveImage = useCallback(
(imageName: string) => {
removeImageFromNodeImageFieldCollectionAction({
imageName,
fieldIdentifier: { nodeId, fieldName: field.name },
dispatch: store.dispatch,
getState: store.getState,
});
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, nodeId, store.dispatch, store.getState]
[field.name, field.value, nodeId, store]
);
return (
@@ -90,7 +87,7 @@ export const ImageFieldCollectionInputComponent = memo(
isError={isInvalid}
onUpload={onUpload}
fontSize={24}
variant="outline"
variant="ghost"
/>
)}
{field.value && field.value.length > 0 && (
@@ -102,9 +99,9 @@ export const ImageFieldCollectionInputComponent = memo(
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(4, 1fr)" gap={1}>
{field.value.map(({ image_name }) => (
<GridItem key={image_name} position="relative" className="nodrag">
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<ImageGridItemContent value={value} index={index} onRemoveImage={onRemoveImage} />
</GridItem>
))}
</Grid>
@@ -124,11 +121,11 @@ export const ImageFieldCollectionInputComponent = memo(
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
const ImageGridItemContent = memo(
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
const query = useGetImageDTOQuery(imageName);
({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => {
const query = useGetImageDTOQuery(value.image_name);
const onClickRemove = useCallback(() => {
onRemoveImage(imageName);
}, [imageName, onRemoveImage]);
onRemoveImage(index);
}, [index, onRemoveImage]);
if (query.isLoading) {
return <IAINoContentFallbackWithSpinner />;

View File

@@ -0,0 +1,320 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import {
Button,
CompositeNumberInput,
Divider,
Flex,
FormControl,
FormLabel,
Grid,
GridItem,
IconButton,
Switch,
Text,
} from '@invoke-ai/ui-library';
import { NUMPY_RAND_MAX } from 'app/constants';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { FloatRangeGenerator } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import {
fieldNumberCollectionGeneratorCommitted,
fieldNumberCollectionGeneratorStateChanged,
fieldNumberCollectionGeneratorToggled,
fieldNumberCollectionLockLinearViewToggled,
fieldNumberCollectionValueChanged,
} from 'features/nodes/store/nodesSlice';
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import type {
FloatRangeStartStepCountGenerator,
IntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { isNil, round } from 'lodash-es';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiLockSimpleFill, PiLockSimpleOpenFill, PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;
export const NumberFieldCollectionInputComponent = memo(
(
props:
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
) => {
const { nodeId, field, fieldTemplate, isLinearView } = props;
const store = useAppStore();
const { t } = useTranslation();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const isIntegerField = useMemo(() => fieldTemplate.type.name === 'IntegerField', [fieldTemplate.type]);
const onRemoveNumber = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onChangeNumber = useCallback(
(index: number, value: number) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onAddNumber = useCallback(() => {
const newValue = field.value ? [...field.value, 0] : [0];
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
const min = useMemo(() => {
let min = -NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.minimum)) {
min = fieldTemplate.minimum;
}
if (!isNil(fieldTemplate.exclusiveMinimum)) {
min = fieldTemplate.exclusiveMinimum + 0.01;
}
return min;
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);
const max = useMemo(() => {
let max = NUMPY_RAND_MAX;
if (!isNil(fieldTemplate.maximum)) {
max = fieldTemplate.maximum;
}
if (!isNil(fieldTemplate.exclusiveMaximum)) {
max = fieldTemplate.exclusiveMaximum - 0.01;
}
return max;
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);
const step = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.1;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
const fineStep = useMemo(() => {
if (isNil(fieldTemplate.multipleOf)) {
return isIntegerField ? 1 : 0.01;
}
return fieldTemplate.multipleOf;
}, [fieldTemplate.multipleOf, isIntegerField]);
const toggleGenerator = useCallback(() => {
store.dispatch(fieldNumberCollectionGeneratorToggled({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const onChangeGenerator = useCallback(
(generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator) => {
store.dispatch(fieldNumberCollectionGeneratorStateChanged({ nodeId, fieldName: field.name, generatorState }));
},
[field.name, nodeId, store]
);
const onCommitGenerator = useCallback(() => {
store.dispatch(fieldNumberCollectionGeneratorCommitted({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const onToggleLockLinearView = useCallback(() => {
store.dispatch(fieldNumberCollectionLockLinearViewToggled({ nodeId, fieldName: field.name }));
}, [field.name, nodeId, store]);
const valuesAsString = useMemo(() => {
const resolvedValue = resolveNumberFieldCollectionValue(field);
return resolvedValue ? resolvedValue.map((val) => round(val, 2)).join(', ') : '';
}, [field]);
const isLockedOnLinearView = !(field.lockLinearView && isLinearView);
return (
<Flex
className="nodrag"
position="relative"
w="full"
h="auto"
maxH={64}
alignItems="stretch"
justifyContent="center"
p={1}
sx={sx}
data-error={isInvalid}
borderRadius="base"
flexDir="column"
gap={1}
>
<Flex w="full" gap={2}>
{!field.generator && (
<Button onClick={onAddNumber} variant="ghost" flexGrow={1} size="sm">
{t('nodes.addValue')}
</Button>
)}
{field.generator && isLockedOnLinearView && (
<Button
tooltip={
<Flex p={1} flexDir="column">
<Text fontWeight="semibold">{t('nodes.generatedValues')}:</Text>
<Text fontFamily="monospace">{valuesAsString}</Text>
</Flex>
}
onClick={onCommitGenerator}
variant="ghost"
flexGrow={1}
size="sm"
>
{t('nodes.commitValues')}
</Button>
)}
{isLockedOnLinearView && (
<FormControl w="min-content" pe={isLinearView ? 2 : undefined}>
<FormLabel m={0}>{t('nodes.generator')}</FormLabel>
<Switch onChange={toggleGenerator} isChecked={Boolean(field.generator)} size="sm" />
</FormControl>
)}
{!isLinearView && (
<IconButton
onClick={onToggleLockLinearView}
tooltip={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
aria-label={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
icon={field.lockLinearView ? <PiLockSimpleFill /> : <PiLockSimpleOpenFill />}
variant="ghost"
size="sm"
/>
)}
</Flex>
{!field.generator && field.value && field.value.length > 0 && (
<>
{!(field.lockLinearView && isLinearView) && <Divider />}
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid gap={1} gridTemplateColumns="auto 1fr auto" alignItems="center">
{field.value.map((value, index) => (
<NumberListItemContent
key={index}
value={value}
index={index}
min={min}
max={max}
step={step}
fineStep={fineStep}
isIntegerField={isIntegerField}
onRemoveNumber={onRemoveNumber}
onChangeNumber={onChangeNumber}
/>
))}
</Grid>
</OverlayScrollbarsComponent>
</>
)}
{field.generator && field.generator.type === 'float-range-generator-start-step-count' && (
<>
{!(field.lockLinearView && isLinearView) && <Divider />}
<FloatRangeGenerator state={field.generator} onChange={onChangeGenerator} />
</>
)}
</Flex>
);
}
);
NumberFieldCollectionInputComponent.displayName = 'NumberFieldCollectionInputComponent';
type NumberListItemContentProps = {
value: number;
index: number;
isIntegerField: boolean;
min: number;
max: number;
step: number;
fineStep: number;
onRemoveNumber: (index: number) => void;
onChangeNumber: (index: number, value: number) => void;
};
const NumberListItemContent = memo(
({
value,
index,
isIntegerField,
min,
max,
step,
fineStep,
onRemoveNumber,
onChangeNumber,
}: NumberListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveNumber(index);
}, [index, onRemoveNumber]);
const onChange = useCallback(
(v: number) => {
onChangeNumber(index, isIntegerField ? Math.floor(Number(v)) : Number(v));
},
[index, isIntegerField, onChangeNumber]
);
return (
<>
<GridItem>
<FormLabel ps={1} m={0}>
{index + 1}.
</FormLabel>
</GridItem>
<GridItem>
<CompositeNumberInput
onChange={onChange}
value={value}
min={min}
max={max}
step={step}
fineStep={fineStep}
className="nodrag"
flexGrow={1}
/>
</GridItem>
<GridItem>
<IconButton
tabIndex={-1}
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.delete')}
/>
</GridItem>
</>
);
}
);
NumberListItemContent.displayName = 'NumberListItemContent';

View File

@@ -0,0 +1,152 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Flex, Grid, GridItem, IconButton, Input } from '@invoke-ai/ui-library';
import { useAppStore } from 'app/store/nanostores/store';
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
import { fieldStringCollectionValueChanged } from 'features/nodes/store/nodesSlice';
import type {
StringFieldCollectionInputInstance,
StringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
import type { ChangeEvent } from 'react';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold, PiXBold } from 'react-icons/pi';
import type { FieldComponentProps } from './types';
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
const sx = {
borderWidth: 1,
'&[data-error=true]': {
borderColor: 'error.500',
borderStyle: 'solid',
},
} satisfies SystemStyleObject;
export const StringFieldCollectionInputComponent = memo(
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
const { nodeId, field } = props;
const store = useAppStore();
const isInvalid = useFieldIsInvalid(nodeId, field.name);
const onRemoveString = useCallback(
(index: number) => {
const newValue = field.value ? [...field.value] : [];
newValue.splice(index, 1);
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onChangeString = useCallback(
(index: number, value: string) => {
const newValue = field.value ? [...field.value] : [];
newValue[index] = value;
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
},
[field.name, field.value, nodeId, store]
);
const onAddString = useCallback(() => {
const newValue = field.value ? [...field.value, ''] : [''];
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
}, [field.name, field.value, nodeId, store]);
return (
<Flex
className="nodrag"
position="relative"
w="full"
h="full"
maxH={64}
alignItems="stretch"
justifyContent="center"
>
{(!field.value || field.value.length === 0) && (
<Box w="full" sx={sx} data-error={isInvalid} borderRadius="base">
<IconButton
w="full"
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
</Box>
)}
{field.value && field.value.length > 0 && (
<Box w="full" h="auto" p={1} sx={sx} data-error={isInvalid} borderRadius="base">
<OverlayScrollbarsComponent
className="nowheel"
defer
style={overlayScrollbarsStyles}
options={overlayscrollbarsOptions}
>
<Grid w="full" h="full" templateColumns="repeat(1, 1fr)" gap={1}>
<IconButton
onClick={onAddString}
aria-label="Add Item"
icon={<PiPlusBold />}
variant="ghost"
size="sm"
/>
{field.value.map((value, index) => (
<GridItem key={index} position="relative" className="nodrag">
<StringListItemContent
value={value}
index={index}
onRemoveString={onRemoveString}
onChangeString={onChangeString}
/>
</GridItem>
))}
</Grid>
</OverlayScrollbarsComponent>
</Box>
)}
</Flex>
);
}
);
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
type StringListItemContentProps = {
value: string;
index: number;
onRemoveString: (index: number) => void;
onChangeString: (index: number, value: string) => void;
};
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
const { t } = useTranslation();
const onClickRemove = useCallback(() => {
onRemoveString(index);
}, [index, onRemoveString]);
const onChange = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
onChangeString(index, e.target.value);
},
[index, onChangeString]
);
return (
<Flex alignItems="center" gap={1}>
<Input size="xs" resize="none" value={value} onChange={onChange} />
<IconButton
size="sm"
variant="link"
alignSelf="stretch"
onClick={onClickRemove}
icon={<PiXBold />}
aria-label={t('common.remove')}
tooltip={t('common.remove')}
/>
</Flex>
);
});
StringListItemContent.displayName = 'StringListItemContent';

View File

@@ -4,4 +4,5 @@ export type FieldComponentProps<V extends FieldInputInstance, T extends FieldInp
nodeId: string;
field: V;
fieldTemplate: T;
isLinearView: boolean;
};

View File

@@ -1,12 +1,14 @@
import type { SystemStyleObject } from '@invoke-ai/ui-library';
import { Box, Editable, EditableInput, EditablePreview, Flex, useEditableControls } from '@invoke-ai/ui-library';
import type { SystemStyleObject, TextProps } from '@invoke-ai/ui-library';
import { Box, Editable, EditableInput, Flex, Text, useEditableControls } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
import type { MouseEvent } from 'react';
import { memo, useCallback, useEffect, useState } from 'react';
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
type Props = {
@@ -17,6 +19,8 @@ type Props = {
const NodeTitle = ({ nodeId, title }: Props) => {
const dispatch = useAppDispatch();
const label = useNodeLabel(nodeId);
const batchGroupId = useBatchGroupId(nodeId);
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
const templateTitle = useNodeTemplateTitle(nodeId);
const { t } = useTranslation();
@@ -29,6 +33,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
[dispatch, nodeId, title, templateTitle, label, t]
);
const localTitleWithBatchGroupId = useMemo(() => {
if (!batchGroupId) {
return localTitle;
}
if (batchGroupId === 'None') {
return `${localTitle} (${t('nodes.noBatchGroup')})`;
}
return `${localTitle} (${batchGroupId})`;
}, [batchGroupId, localTitle, t]);
const handleChange = useCallback((newTitle: string) => {
setLocalTitle(newTitle);
}, []);
@@ -50,7 +64,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
w="full"
h="full"
>
<EditablePreview fontSize="sm" p={0} w="full" noOfLines={1} />
<Preview
fontSize="sm"
p={0}
w="full"
noOfLines={1}
color={batchGroupColorToken}
fontWeight={batchGroupId ? 'semibold' : undefined}
>
{localTitleWithBatchGroupId}
</Preview>
<EditableInput className="nodrag" fontSize="sm" sx={editableInputStyles} />
<EditableControls />
</Editable>
@@ -60,6 +83,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
export default memo(NodeTitle);
const Preview = (props: TextProps) => {
const { isEditing } = useEditableControls();
if (isEditing) {
return null;
}
return <Text {...props} />;
};
function EditableControls() {
const { isEditing, getEditButtonProps } = useEditableControls();
const handleDoubleClick = useCallback(

View File

@@ -5,7 +5,7 @@ import NodeOpacitySlider from './NodeOpacitySlider';
import ViewportControls from './ViewportControls';
const BottomLeftPanel = () => (
<Flex gap={2} position="absolute" bottom={0} insetInlineStart={0}>
<Flex gap={2} position="absolute" bottom={2} insetInlineStart={2}>
<ViewportControls />
<NodeOpacitySlider />
</Flex>

View File

@@ -20,7 +20,7 @@ const MinimapPanel = () => {
const shouldShowMinimapPanel = useAppSelector(selectShouldShowMinimapPanel);
return (
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
<Flex gap={2} position="absolute" bottom={2} insetInlineEnd={2}>
{shouldShowMinimapPanel && (
<ChakraMiniMap
pannable

View File

@@ -12,7 +12,7 @@ import { memo } from 'react';
const TopCenterPanel = () => {
const name = useAppSelector(selectWorkflowName);
return (
<Flex gap={2} top={0} left={0} right={0} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
<Flex gap="2">
<AddNodeButton />
<UpdateNodesButton />

View File

@@ -46,7 +46,7 @@ const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
</Flex>
</Tooltip>
</Flex>
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={true} />
</Flex>
);
};

View File

@@ -0,0 +1,22 @@
import { useMemo } from 'react';
export const useBatchGroupColorToken = (batchGroupId?: string) => {
const batchGroupColorToken = useMemo(() => {
switch (batchGroupId) {
case 'Group 1':
return 'invokeGreen.300';
case 'Group 2':
return 'invokeBlue.300';
case 'Group 3':
return 'invokePurple.200';
case 'Group 4':
return 'invokeRed.300';
case 'Group 5':
return 'invokeYellow.300';
default:
return undefined;
}
}, [batchGroupId]);
return batchGroupColorToken;
};

View File

@@ -0,0 +1,19 @@
import { useNode } from 'features/nodes/hooks/useNode';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { useMemo } from 'react';
export const useBatchGroupId = (nodeId: string) => {
const node = useNode(nodeId);
const batchGroupId = useMemo(() => {
if (!isInvocationNode(node)) {
return;
}
if (!isBatchNode(node)) {
return;
}
return node.data.inputs['batch_group_id']?.value as string;
}, [node]);
return batchGroupId;
};

View File

@@ -3,7 +3,21 @@ import { useAppSelector } from 'app/store/storeHooks';
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import {
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import {
validateImageFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import { useMemo } from 'react';
export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
@@ -35,13 +49,27 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
}
// Else special handling for individual field types
if (isImageFieldCollectionInputInstance(field) && isImageFieldCollectionInputTemplate(template)) {
// Image collections may have min or max item counts
if (template.minItems !== undefined && field.value.length < template.minItems) {
if (validateImageFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
if (validateStringFieldCollectionValue(field.value, template).length > 0) {
return true;
}
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field, template).length > 0) {
return true;
}
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
if (validateNumberFieldCollectionValue(field, template).length > 0) {
return true;
}
}

View File

@@ -1,8 +1,9 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { isEqual } from 'lodash-es';
import { useCallback, useMemo } from 'react';
@@ -10,19 +11,38 @@ export const useFieldOriginalValue = (nodeId: string, fieldName: string) => {
const dispatch = useAppDispatch();
const selectOriginalExposedFieldValues = useMemo(
() =>
createSelector(
selectWorkflowSlice,
(workflow) =>
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)?.value
createMemoizedSelector(selectWorkflowSlice, (workflow) =>
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)
),
[nodeId, fieldName]
);
const originalValue = useAppSelector(selectOriginalExposedFieldValues);
const value = useFieldValue(nodeId, fieldName);
const isValueChanged = useMemo(() => !isEqual(value, originalValue), [value, originalValue]);
const exposedField = useAppSelector(selectOriginalExposedFieldValues);
const field = useFieldInputInstance(nodeId, fieldName);
const isValueChanged = useMemo(() => {
if (!field) {
// Field is not found, so it is not changed
return false;
}
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputInstance(exposedField?.field)) {
return !isEqual(field.generator, exposedField.field.generator);
}
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputInstance(exposedField?.field)) {
return !isEqual(field.generator, exposedField.field.generator);
}
return !isEqual(field.value, exposedField?.field.value);
}, [field, exposedField]);
const onReset = useCallback(() => {
dispatch(fieldValueReset({ nodeId, fieldName, value: originalValue }));
}, [dispatch, fieldName, nodeId, originalValue]);
if (!exposedField) {
return;
}
const { value } = exposedField.field;
const generator =
isIntegerFieldCollectionInputInstance(exposedField.field) ||
isFloatFieldCollectionInputInstance(exposedField.field)
? exposedField.field.generator
: undefined;
dispatch(fieldValueReset({ nodeId, fieldName, value, generator }));
}, [dispatch, fieldName, nodeId, exposedField]);
return { originalValue, isValueChanged, onReset };
return { originalValue: exposedField, isValueChanged, onReset };
};

View File

@@ -19,6 +19,7 @@ import type {
FluxVAEModelFieldValue,
ImageFieldCollectionValue,
ImageFieldValue,
IntegerFieldCollectionValue,
IntegerFieldValue,
IPAdapterModelFieldValue,
LoRAModelFieldValue,
@@ -28,12 +29,15 @@ import type {
SDXLRefinerModelFieldValue,
SpandrelImageToImageModelFieldValue,
StatefulFieldValue,
StringFieldCollectionValue,
StringFieldValue,
T2IAdapterModelFieldValue,
T5EncoderModelFieldValue,
VAEModelFieldValue,
} from 'features/nodes/types/field';
import {
isFloatFieldCollectionInputInstance,
isIntegerFieldCollectionInputInstance,
zBoardFieldValue,
zBooleanFieldValue,
zCLIPEmbedModelFieldValue,
@@ -43,10 +47,12 @@ import {
zControlLoRAModelFieldValue,
zControlNetModelFieldValue,
zEnumFieldValue,
zFloatFieldCollectionValue,
zFloatFieldValue,
zFluxVAEModelFieldValue,
zImageFieldCollectionValue,
zImageFieldValue,
zIntegerFieldCollectionValue,
zIntegerFieldValue,
zIPAdapterModelFieldValue,
zLoRAModelFieldValue,
@@ -56,11 +62,22 @@ import {
zSDXLRefinerModelFieldValue,
zSpandrelImageToImageModelFieldValue,
zStatefulFieldValue,
zStringFieldCollectionValue,
zStringFieldValue,
zT2IAdapterModelFieldValue,
zT5EncoderModelFieldValue,
zVAEModelFieldValue,
} from 'features/nodes/types/field';
import type {
FloatRangeStartStepCountGenerator,
IntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import {
floatRangeStartStepCountGenerator,
getDefaultFloatRangeStartStepCountGenerator,
getDefaultIntegerRangeStartStepCountGenerator,
integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
import { atom, computed } from 'nanostores';
@@ -78,11 +95,22 @@ const initialNodesState: NodesState = {
edges: [],
};
type FieldValueAction<T extends FieldValue> = PayloadAction<{
nodeId: string;
fieldName: string;
value: T;
}>;
type FieldValueAction<T extends FieldValue, U = unknown> = PayloadAction<
{
nodeId: string;
fieldName: string;
value: T;
} & U
>;
const selectField = (state: NodesState, nodeId: string, fieldName: string) => {
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
return node.data?.inputs[fieldName];
};
const fieldValueReducer = <T extends FieldValue>(
state: NodesState,
@@ -90,17 +118,24 @@ const fieldValueReducer = <T extends FieldValue>(
schema: z.ZodTypeAny
) => {
const { nodeId, fieldName, value } = action.payload;
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
const node = state.nodes?.[nodeIndex];
if (!isInvocationNode(node)) {
return;
}
const input = node.data?.inputs[fieldName];
const field = selectField(state, nodeId, fieldName);
const result = schema.safeParse(value);
if (!input || nodeIndex < 0 || !result.success) {
if (!field || !result.success) {
return;
}
input.value = result.data;
field.value = result.data;
// Special handling if the field value is being reset
if (result.data === undefined) {
if (isFloatFieldCollectionInputInstance(field)) {
if (field.lockLinearView && field.generator) {
field.generator = getDefaultFloatRangeStartStepCountGenerator();
}
} else if (isIntegerFieldCollectionInputInstance(field)) {
if (field.lockLinearView && field.generator) {
field.generator = getDefaultIntegerRangeStartStepCountGenerator();
}
}
}
};
export const nodesSlice = createSlice({
@@ -305,15 +340,123 @@ export const nodesSlice = createSlice({
}
node.data.notes = notes;
},
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
fieldValueReducer(state, action, zStatefulFieldValue);
fieldValueReset: (
state,
action: FieldValueAction<
StatefulFieldValue,
{ generator?: IntegerRangeStartStepCountGenerator | FloatRangeStartStepCountGenerator }
>
) => {
const { nodeId, fieldName, value, generator } = action.payload;
const field = selectField(state, nodeId, fieldName);
const result = zStatefulFieldValue.safeParse(value);
if (!field || !result.success) {
return;
}
field.value = result.data;
if (isFloatFieldCollectionInputInstance(field) && generator?.type === 'float-range-generator-start-step-count') {
field.generator = generator;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
generator?.type === 'integer-range-generator-start-step-count'
) {
field.generator = generator;
}
},
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
fieldValueReducer(state, action, zStringFieldValue);
},
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
fieldValueReducer(state, action, zStringFieldCollectionValue);
},
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
},
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
},
fieldNumberCollectionGeneratorToggled: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (isFloatFieldCollectionInputInstance(field)) {
field.generator = field.generator ? undefined : getDefaultFloatRangeStartStepCountGenerator();
} else if (isIntegerFieldCollectionInputInstance(field)) {
field.generator = field.generator ? undefined : getDefaultIntegerRangeStartStepCountGenerator();
} else {
// This should never happen
}
},
fieldNumberCollectionGeneratorStateChanged: (
state,
action: PayloadAction<{
nodeId: string;
fieldName: string;
generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator;
}>
) => {
const { nodeId, fieldName, generatorState } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (
isFloatFieldCollectionInputInstance(field) &&
generatorState.type === 'float-range-generator-start-step-count'
) {
field.generator = generatorState;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
generatorState.type === 'integer-range-generator-start-step-count'
) {
field.generator = generatorState;
} else {
// This should never happen
}
},
fieldNumberCollectionGeneratorCommitted: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (
isFloatFieldCollectionInputInstance(field) &&
field.generator &&
field.generator.type === 'float-range-generator-start-step-count'
) {
field.value = floatRangeStartStepCountGenerator(field.generator);
field.generator = undefined;
} else if (
isIntegerFieldCollectionInputInstance(field) &&
field.generator &&
field.generator.type === 'integer-range-generator-start-step-count'
) {
field.value = integerRangeStartStepCountGenerator(field.generator);
field.generator = undefined;
} else {
// This should never happen
}
},
fieldNumberCollectionLockLinearViewToggled: (
state,
action: PayloadAction<{ nodeId: string; fieldName: string }>
) => {
const { nodeId, fieldName } = action.payload;
const field = selectField(state, nodeId, fieldName);
if (!field) {
return;
}
if (!isFloatFieldCollectionInputInstance(field) && !isIntegerFieldCollectionInputInstance(field)) {
return;
}
field.lockLinearView = !field.lockLinearView;
},
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
fieldValueReducer(state, action, zBooleanFieldValue);
},
@@ -435,9 +578,15 @@ export const {
fieldModelIdentifierValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldNumberCollectionGeneratorToggled,
fieldNumberCollectionGeneratorStateChanged,
fieldNumberCollectionGeneratorCommitted,
fieldNumberCollectionLockLinearViewToggled,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,
@@ -546,9 +695,11 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
fieldLoRAModelValueChanged,
fieldMainModelValueChanged,
fieldNumberValueChanged,
fieldNumberCollectionValueChanged,
fieldRefinerModelValueChanged,
fieldSchedulerValueChanged,
fieldStringValueChanged,
fieldStringCollectionValueChanged,
fieldVaeModelValueChanged,
fieldT5EncoderValueChanged,
fieldCLIPEmbedValueChanged,

View File

@@ -1,8 +1,8 @@
import type {
FieldIdentifier,
FieldInputInstance,
FieldInputTemplate,
FieldOutputTemplate,
StatefulFieldValue,
} from 'features/nodes/types/field';
import type {
AnyNode,
@@ -31,15 +31,15 @@ export type NodesState = {
};
export type WorkflowMode = 'edit' | 'view';
export type FieldIdentifierWithValue = FieldIdentifier & {
value: StatefulFieldValue;
export type FieldIdentifierWithInstance = FieldIdentifier & {
field: FieldInputInstance;
};
export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
_version: 1;
_version: 2;
isTouched: boolean;
mode: WorkflowMode;
originalExposedFieldValues: FieldIdentifierWithValue[];
originalExposedFieldValues: FieldIdentifierWithInstance[];
searchTerm: string;
orderBy?: WorkflowRecordOrderBy;
orderDirection: SQLiteDirection;

View File

@@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
import { workflowLoaded } from 'features/nodes/store/actions';
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
import type {
FieldIdentifierWithValue,
FieldIdentifierWithInstance,
WorkflowMode,
WorkflowsState as WorkflowState,
} from 'features/nodes/store/types';
@@ -31,7 +31,7 @@ const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
};
const initialWorkflowState: WorkflowState = {
_version: 1,
_version: 2,
isTouched: false,
mode: 'view',
originalExposedFieldValues: [],
@@ -62,7 +62,7 @@ export const workflowSlice = createSlice({
const { id, isOpen } = action.payload;
state.categorySections[id] = isOpen;
},
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithValue>) => {
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithInstance>) => {
state.exposedFields = uniqBy(
state.exposedFields.concat(omit(action.payload, 'value')),
(field) => `${field.nodeId}-${field.fieldName}`
@@ -128,25 +128,25 @@ export const workflowSlice = createSlice({
builder.addCase(workflowLoaded, (state, action) => {
const { nodes, edges: _edges, ...workflowExtra } = action.payload;
const originalExposedFieldValues: FieldIdentifierWithValue[] = [];
const originalExposedFieldValues: FieldIdentifierWithInstance[] = [];
workflowExtra.exposedFields.forEach((field) => {
const node = nodes.find((n) => n.id === field.nodeId);
workflowExtra.exposedFields.forEach(({ nodeId, fieldName }) => {
const node = nodes.find((n) => n.id === nodeId);
if (!isInvocationNode(node)) {
return;
}
const input = node.data.inputs[field.fieldName];
const field = node.data.inputs[fieldName];
if (!input) {
if (!field) {
return;
}
const originalExposedFieldValue = {
nodeId: field.nodeId,
fieldName: field.fieldName,
value: input.value,
nodeId,
fieldName,
field,
};
originalExposedFieldValues.push(originalExposedFieldValue);
});
@@ -243,6 +243,9 @@ const migrateWorkflowState = (state: any): any => {
if (!('_version' in state)) {
state._version = 1;
}
if (state._version === 1) {
return deepClone(initialWorkflowState);
}
return state;
};

View File

@@ -1,3 +1,8 @@
import {
zFloatRangeStartStepCountGenerator,
zIntegerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
import { z } from 'zod';
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
@@ -78,14 +83,35 @@ const zIntegerFieldType = zFieldTypeBase.extend({
name: z.literal('IntegerField'),
originalType: zStatelessFieldType.optional(),
});
const zIntegerCollectionFieldType = z.object({
name: z.literal('IntegerField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isIntegerCollectionFieldType = buildTypeGuard(zIntegerCollectionFieldType);
const zFloatFieldType = zFieldTypeBase.extend({
name: z.literal('FloatField'),
originalType: zStatelessFieldType.optional(),
});
const zFloatCollectionFieldType = z.object({
name: z.literal('FloatField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isFloatCollectionFieldType = buildTypeGuard(zFloatCollectionFieldType);
const zStringFieldType = zFieldTypeBase.extend({
name: z.literal('StringField'),
originalType: zStatelessFieldType.optional(),
});
const zStringCollectionFieldType = z.object({
name: z.literal('StringField'),
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isStringCollectionFieldType = buildTypeGuard(zStringCollectionFieldType);
const zBooleanFieldType = zFieldTypeBase.extend({
name: z.literal('BooleanField'),
originalType: zStatelessFieldType.optional(),
@@ -103,9 +129,7 @@ const zImageCollectionFieldType = z.object({
cardinality: z.literal(COLLECTION),
originalType: zStatelessFieldType.optional(),
});
export const isImageCollectionFieldType = (
fieldType: FieldType
): fieldType is z.infer<typeof zImageCollectionFieldType> => zImageCollectionFieldType.safeParse(fieldType).success;
export const isImageCollectionFieldType = buildTypeGuard(zImageCollectionFieldType);
const zBoardFieldType = zFieldTypeBase.extend({
name: z.literal('BoardField'),
originalType: zStatelessFieldType.optional(),
@@ -254,10 +278,48 @@ const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
export type IntegerFieldInputTemplate = z.infer<typeof zIntegerFieldInputTemplate>;
export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance =>
zIntegerFieldInputInstance.safeParse(val).success;
export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate =>
zIntegerFieldInputTemplate.safeParse(val).success;
export const isIntegerFieldInputInstance = buildTypeGuard(zIntegerFieldInputInstance);
export const isIntegerFieldInputTemplate = buildTypeGuard(zIntegerFieldInputTemplate);
// #endregion
// #region IntegerField Collection
export const zIntegerFieldCollectionValue = z.array(zIntegerFieldValue).optional();
const zIntegerFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zIntegerFieldCollectionValue,
generator: zIntegerRangeStartStepCountGenerator.optional(),
lockLinearView: z.boolean().default(false),
});
const zIntegerFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zIntegerCollectionFieldType,
originalType: zFieldType.optional(),
default: zIntegerFieldCollectionValue,
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
multipleOf: z.number().int().optional(),
maximum: z.number().int().optional(),
exclusiveMaximum: z.number().int().optional(),
minimum: z.number().int().optional(),
exclusiveMinimum: z.number().int().optional(),
})
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zIntegerFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zIntegerCollectionFieldType,
});
export type IntegerFieldCollectionValue = z.infer<typeof zIntegerFieldCollectionValue>;
export type IntegerFieldCollectionInputInstance = z.infer<typeof zIntegerFieldCollectionInputInstance>;
export type IntegerFieldCollectionInputTemplate = z.infer<typeof zIntegerFieldCollectionInputTemplate>;
export const isIntegerFieldCollectionInputInstance = buildTypeGuard(zIntegerFieldCollectionInputInstance);
export const isIntegerFieldCollectionInputTemplate = buildTypeGuard(zIntegerFieldCollectionInputTemplate);
// #endregion
// #region FloatField
@@ -282,10 +344,48 @@ const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance =>
zFloatFieldInputInstance.safeParse(val).success;
export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate =>
zFloatFieldInputTemplate.safeParse(val).success;
export const isFloatFieldInputInstance = buildTypeGuard(zFloatFieldInputInstance);
export const isFloatFieldInputTemplate = buildTypeGuard(zFloatFieldInputTemplate);
// #endregion
// #region FloatField Collection
export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zFloatFieldCollectionValue,
generator: zFloatRangeStartStepCountGenerator.optional(),
lockLinearView: z.boolean().default(false),
});
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zFloatCollectionFieldType,
originalType: zFieldType.optional(),
default: zFloatFieldCollectionValue,
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
multipleOf: z.number().int().optional(),
maximum: z.number().optional(),
exclusiveMaximum: z.number().optional(),
minimum: z.number().optional(),
exclusiveMinimum: z.number().optional(),
})
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zFloatCollectionFieldType,
});
export type FloatFieldCollectionInputInstance = z.infer<typeof zFloatFieldCollectionInputInstance>;
export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollectionInputTemplate>;
export const isFloatFieldCollectionInputInstance = buildTypeGuard(zFloatFieldCollectionInputInstance);
export const isFloatFieldCollectionInputTemplate = buildTypeGuard(zFloatFieldCollectionInputTemplate);
// #endregion
// #region StringField
@@ -315,13 +415,55 @@ const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringFieldType,
});
// #region StringField Collection
export const zStringFieldCollectionValue = z.array(zStringFieldValue).optional();
const zStringFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
value: zStringFieldCollectionValue,
});
const zStringFieldCollectionInputTemplate = zFieldInputTemplateBase
.extend({
type: zStringCollectionFieldType,
originalType: zFieldType.optional(),
default: zStringFieldCollectionValue,
maxLength: z.number().int().gte(0).optional(),
minLength: z.number().int().gte(0).optional(),
maxItems: z.number().int().gte(0).optional(),
minItems: z.number().int().gte(0).optional(),
})
.refine(
(val) => {
if (val.maxLength !== undefined && val.minLength !== undefined) {
return val.maxLength >= val.minLength;
}
return true;
},
{ message: 'maxLength must be greater than or equal to minLength' }
)
.refine(
(val) => {
if (val.maxItems !== undefined && val.minItems !== undefined) {
return val.maxItems >= val.minItems;
}
return true;
},
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zStringFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
type: zStringCollectionFieldType,
});
export type StringFieldCollectionValue = z.infer<typeof zStringFieldCollectionValue>;
export type StringFieldCollectionInputInstance = z.infer<typeof zStringFieldCollectionInputInstance>;
export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldCollectionInputTemplate>;
export const isStringFieldCollectionInputInstance = buildTypeGuard(zStringFieldCollectionInputInstance);
export const isStringFieldCollectionInputTemplate = buildTypeGuard(zStringFieldCollectionInputTemplate);
// #endregion
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance =>
zStringFieldInputInstance.safeParse(val).success;
export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate =>
zStringFieldInputTemplate.safeParse(val).success;
export const isStringFieldInputInstance = buildTypeGuard(zStringFieldInputInstance);
export const isStringFieldInputTemplate = buildTypeGuard(zStringFieldInputTemplate);
// #endregion
// #region BooleanField
@@ -341,10 +483,8 @@ const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
export type BooleanFieldInputTemplate = z.infer<typeof zBooleanFieldInputTemplate>;
export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance =>
zBooleanFieldInputInstance.safeParse(val).success;
export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate =>
zBooleanFieldInputTemplate.safeParse(val).success;
export const isBooleanFieldInputInstance = buildTypeGuard(zBooleanFieldInputInstance);
export const isBooleanFieldInputTemplate = buildTypeGuard(zBooleanFieldInputTemplate);
// #endregion
// #region EnumField
@@ -366,10 +506,8 @@ const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
export type EnumFieldInputTemplate = z.infer<typeof zEnumFieldInputTemplate>;
export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance =>
zEnumFieldInputInstance.safeParse(val).success;
export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate =>
zEnumFieldInputTemplate.safeParse(val).success;
export const isEnumFieldInputInstance = buildTypeGuard(zEnumFieldInputInstance);
export const isEnumFieldInputTemplate = buildTypeGuard(zEnumFieldInputTemplate);
// #endregion
// #region ImageField
@@ -388,10 +526,8 @@ const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance =>
zImageFieldInputInstance.safeParse(val).success;
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
zImageFieldInputTemplate.safeParse(val).success;
export const isImageFieldInputInstance = buildTypeGuard(zImageFieldInputInstance);
export const isImageFieldInputTemplate = buildTypeGuard(zImageFieldInputTemplate);
// #endregion
// #region ImageField Collection
@@ -414,7 +550,7 @@ const zImageFieldCollectionInputTemplate = zFieldInputTemplateBase
}
return true;
},
{ message: 'maxLength must be greater than or equal to minLength' }
{ message: 'maxItems must be greater than or equal to minItems' }
);
const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
@@ -423,10 +559,8 @@ const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
export type ImageFieldCollectionValue = z.infer<typeof zImageFieldCollectionValue>;
export type ImageFieldCollectionInputInstance = z.infer<typeof zImageFieldCollectionInputInstance>;
export type ImageFieldCollectionInputTemplate = z.infer<typeof zImageFieldCollectionInputTemplate>;
export const isImageFieldCollectionInputInstance = (val: unknown): val is ImageFieldCollectionInputInstance =>
zImageFieldCollectionInputInstance.safeParse(val).success;
export const isImageFieldCollectionInputTemplate = (val: unknown): val is ImageFieldCollectionInputTemplate =>
zImageFieldCollectionInputTemplate.safeParse(val).success;
export const isImageFieldCollectionInputInstance = buildTypeGuard(zImageFieldCollectionInputInstance);
export const isImageFieldCollectionInputTemplate = buildTypeGuard(zImageFieldCollectionInputTemplate);
// #endregion
// #region BoardField
@@ -446,10 +580,8 @@ const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
export type BoardFieldInputTemplate = z.infer<typeof zBoardFieldInputTemplate>;
export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance =>
zBoardFieldInputInstance.safeParse(val).success;
export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate =>
zBoardFieldInputTemplate.safeParse(val).success;
export const isBoardFieldInputInstance = buildTypeGuard(zBoardFieldInputInstance);
export const isBoardFieldInputTemplate = buildTypeGuard(zBoardFieldInputTemplate);
// #endregion
// #region ColorField
@@ -469,10 +601,8 @@ const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
export type ColorFieldInputTemplate = z.infer<typeof zColorFieldInputTemplate>;
export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance =>
zColorFieldInputInstance.safeParse(val).success;
export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate =>
zColorFieldInputTemplate.safeParse(val).success;
export const isColorFieldInputInstance = buildTypeGuard(zColorFieldInputInstance);
export const isColorFieldInputTemplate = buildTypeGuard(zColorFieldInputTemplate);
// #endregion
// #region MainModelField
@@ -492,10 +622,8 @@ const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance =>
zMainModelFieldInputInstance.safeParse(val).success;
export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate =>
zMainModelFieldInputTemplate.safeParse(val).success;
export const isMainModelFieldInputInstance = buildTypeGuard(zMainModelFieldInputInstance);
export const isMainModelFieldInputTemplate = buildTypeGuard(zMainModelFieldInputTemplate);
// #endregion
// #region ModelIdentifierField
@@ -514,10 +642,8 @@ const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
export type ModelIdentifierFieldInputTemplate = z.infer<typeof zModelIdentifierFieldInputTemplate>;
export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
zModelIdentifierFieldInputInstance.safeParse(val).success;
export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
zModelIdentifierFieldInputTemplate.safeParse(val).success;
export const isModelIdentifierFieldInputInstance = buildTypeGuard(zModelIdentifierFieldInputInstance);
export const isModelIdentifierFieldInputTemplate = buildTypeGuard(zModelIdentifierFieldInputTemplate);
// #endregion
// #region SDXLMainModelField
@@ -536,10 +662,8 @@ const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance =>
zSDXLMainModelFieldInputInstance.safeParse(val).success;
export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate =>
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
export const isSDXLMainModelFieldInputInstance = buildTypeGuard(zSDXLMainModelFieldInputInstance);
export const isSDXLMainModelFieldInputTemplate = buildTypeGuard(zSDXLMainModelFieldInputTemplate);
// #endregion
// #region SD3MainModelField
@@ -558,10 +682,8 @@ const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
zSD3MainModelFieldInputInstance.safeParse(val).success;
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
zSD3MainModelFieldInputTemplate.safeParse(val).success;
export const isSD3MainModelFieldInputInstance = buildTypeGuard(zSD3MainModelFieldInputInstance);
export const isSD3MainModelFieldInputTemplate = buildTypeGuard(zSD3MainModelFieldInputTemplate);
// #endregion
@@ -581,10 +703,8 @@ const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
});
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
zFluxMainModelFieldInputInstance.safeParse(val).success;
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
zFluxMainModelFieldInputTemplate.safeParse(val).success;
export const isFluxMainModelFieldInputInstance = buildTypeGuard(zFluxMainModelFieldInputInstance);
export const isFluxMainModelFieldInputTemplate = buildTypeGuard(zFluxMainModelFieldInputTemplate);
// #endregion
@@ -606,10 +726,8 @@ const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance =>
zSDXLRefinerModelFieldInputInstance.safeParse(val).success;
export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate =>
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
export const isSDXLRefinerModelFieldInputInstance = buildTypeGuard(zSDXLRefinerModelFieldInputInstance);
export const isSDXLRefinerModelFieldInputTemplate = buildTypeGuard(zSDXLRefinerModelFieldInputTemplate);
// #endregion
// #region VAEModelField
@@ -629,10 +747,8 @@ const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance =>
zVAEModelFieldInputInstance.safeParse(val).success;
export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate =>
zVAEModelFieldInputTemplate.safeParse(val).success;
export const isVAEModelFieldInputInstance = buildTypeGuard(zVAEModelFieldInputInstance);
export const isVAEModelFieldInputTemplate = buildTypeGuard(zVAEModelFieldInputTemplate);
// #endregion
// #region LoRAModelField
@@ -652,10 +768,8 @@ const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance =>
zLoRAModelFieldInputInstance.safeParse(val).success;
export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate =>
zLoRAModelFieldInputTemplate.safeParse(val).success;
export const isLoRAModelFieldInputInstance = buildTypeGuard(zLoRAModelFieldInputInstance);
export const isLoRAModelFieldInputTemplate = buildTypeGuard(zLoRAModelFieldInputTemplate);
// #endregion
// #region ControlNetModelField
@@ -675,10 +789,8 @@ const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance =>
zControlNetModelFieldInputInstance.safeParse(val).success;
export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate =>
zControlNetModelFieldInputTemplate.safeParse(val).success;
export const isControlNetModelFieldInputInstance = buildTypeGuard(zControlNetModelFieldInputInstance);
export const isControlNetModelFieldInputTemplate = buildTypeGuard(zControlNetModelFieldInputTemplate);
// #endregion
// #region IPAdapterModelField
@@ -698,10 +810,8 @@ const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance =>
zIPAdapterModelFieldInputInstance.safeParse(val).success;
export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate =>
zIPAdapterModelFieldInputTemplate.safeParse(val).success;
export const isIPAdapterModelFieldInputInstance = buildTypeGuard(zIPAdapterModelFieldInputInstance);
export const isIPAdapterModelFieldInputTemplate = buildTypeGuard(zIPAdapterModelFieldInputTemplate);
// #endregion
// #region T2IAdapterField
@@ -721,10 +831,8 @@ const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance =>
zT2IAdapterModelFieldInputInstance.safeParse(val).success;
export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate =>
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
export const isT2IAdapterModelFieldInputInstance = buildTypeGuard(zT2IAdapterModelFieldInputInstance);
export const isT2IAdapterModelFieldInputTemplate = buildTypeGuard(zT2IAdapterModelFieldInputTemplate);
// #endregion
// #region SpandrelModelToModelField
@@ -744,14 +852,12 @@ const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.e
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
export const isSpandrelImageToImageModelFieldInputInstance = (
val: unknown
): val is SpandrelImageToImageModelFieldInputInstance =>
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
export const isSpandrelImageToImageModelFieldInputTemplate = (
val: unknown
): val is SpandrelImageToImageModelFieldInputTemplate =>
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
export const isSpandrelImageToImageModelFieldInputInstance = buildTypeGuard(
zSpandrelImageToImageModelFieldInputInstance
);
export const isSpandrelImageToImageModelFieldInputTemplate = buildTypeGuard(
zSpandrelImageToImageModelFieldInputTemplate
);
// #endregion
// #region T5EncoderModelField
@@ -770,10 +876,8 @@ export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>
export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance =>
zT5EncoderModelFieldInputInstance.safeParse(val).success;
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
export const isT5EncoderModelFieldInputInstance = buildTypeGuard(zT5EncoderModelFieldInputInstance);
export const isT5EncoderModelFieldInputTemplate = buildTypeGuard(zT5EncoderModelFieldInputTemplate);
// #endregion
@@ -793,10 +897,8 @@ export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
zFluxVAEModelFieldInputInstance.safeParse(val).success;
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
export const isFluxVAEModelFieldInputInstance = buildTypeGuard(zFluxVAEModelFieldInputInstance);
export const isFluxVAEModelFieldInputTemplate = buildTypeGuard(zFluxVAEModelFieldInputTemplate);
// #endregion
@@ -816,10 +918,8 @@ export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
export const isCLIPEmbedModelFieldInputInstance = buildTypeGuard(zCLIPEmbedModelFieldInputInstance);
export const isCLIPEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPEmbedModelFieldInputTemplate);
// #endregion
@@ -839,10 +939,8 @@ export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValu
export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
export const isCLIPLEmbedModelFieldInputInstance = (val: unknown): val is CLIPLEmbedModelFieldInputInstance =>
zCLIPLEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPLEmbedModelFieldInputTemplate = (val: unknown): val is CLIPLEmbedModelFieldInputTemplate =>
zCLIPLEmbedModelFieldInputTemplate.safeParse(val).success;
export const isCLIPLEmbedModelFieldInputInstance = buildTypeGuard(zCLIPLEmbedModelFieldInputInstance);
export const isCLIPLEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPLEmbedModelFieldInputTemplate);
// #endregion
@@ -862,10 +960,8 @@ export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValu
export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
export const isCLIPGEmbedModelFieldInputInstance = (val: unknown): val is CLIPGEmbedModelFieldInputInstance =>
zCLIPGEmbedModelFieldInputInstance.safeParse(val).success;
export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGEmbedModelFieldInputTemplate =>
zCLIPGEmbedModelFieldInputTemplate.safeParse(val).success;
export const isCLIPGEmbedModelFieldInputInstance = buildTypeGuard(zCLIPGEmbedModelFieldInputInstance);
export const isCLIPGEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPGEmbedModelFieldInputTemplate);
// #endregion
@@ -885,10 +981,8 @@ export type ControlLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldVal
export type ControlLoRAModelFieldInputInstance = z.infer<typeof zControlLoRAModelFieldInputInstance>;
export type ControlLoRAModelFieldInputTemplate = z.infer<typeof zControlLoRAModelFieldInputTemplate>;
export const isControlLoRAModelFieldInputInstance = (val: unknown): val is ControlLoRAModelFieldInputInstance =>
zControlLoRAModelFieldInputInstance.safeParse(val).success;
export const isControlLoRAModelFieldInputTemplate = (val: unknown): val is ControlLoRAModelFieldInputTemplate =>
zControlLoRAModelFieldInputTemplate.safeParse(val).success;
export const isControlLoRAModelFieldInputInstance = buildTypeGuard(zControlLoRAModelFieldInputInstance);
export const isControlLoRAModelFieldInputTemplate = buildTypeGuard(zControlLoRAModelFieldInputTemplate);
// #endregion
@@ -909,10 +1003,8 @@ const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
zSchedulerFieldInputInstance.safeParse(val).success;
export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate =>
zSchedulerFieldInputTemplate.safeParse(val).success;
export const isSchedulerFieldInputInstance = buildTypeGuard(zSchedulerFieldInputInstance);
export const isSchedulerFieldInputTemplate = buildTypeGuard(zSchedulerFieldInputTemplate);
// #endregion
// #region StatelessField
@@ -963,8 +1055,11 @@ export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTem
// #region StatefulFieldValue & FieldValue
export const zStatefulFieldValue = z.union([
zIntegerFieldValue,
zIntegerFieldCollectionValue,
zFloatFieldValue,
zFloatFieldCollectionValue,
zStringFieldValue,
zStringFieldCollectionValue,
zBooleanFieldValue,
zEnumFieldValue,
zImageFieldValue,
@@ -1000,8 +1095,11 @@ export type FieldValue = z.infer<typeof zFieldValue>;
// #region StatefulFieldInputInstance & FieldInputInstance
const zStatefulFieldInputInstance = z.union([
zIntegerFieldInputInstance,
zIntegerFieldCollectionInputInstance,
zFloatFieldInputInstance,
zFloatFieldCollectionInputInstance,
zStringFieldInputInstance,
zStringFieldCollectionInputInstance,
zBooleanFieldInputInstance,
zEnumFieldInputInstance,
zImageFieldInputInstance,
@@ -1028,15 +1126,17 @@ const zStatefulFieldInputInstance = z.union([
export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]);
export type FieldInputInstance = z.infer<typeof zFieldInputInstance>;
export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
zFieldInputInstance.safeParse(val).success;
export const isFieldInputInstance = buildTypeGuard(zFieldInputInstance);
// #endregion
// #region StatefulFieldInputTemplate & FieldInputTemplate
const zStatefulFieldInputTemplate = z.union([
zIntegerFieldInputTemplate,
zIntegerFieldCollectionInputTemplate,
zFloatFieldInputTemplate,
zFloatFieldCollectionInputTemplate,
zStringFieldInputTemplate,
zStringFieldCollectionInputTemplate,
zBooleanFieldInputTemplate,
zEnumFieldInputTemplate,
zImageFieldInputTemplate,
@@ -1067,15 +1167,17 @@ const zStatefulFieldInputTemplate = z.union([
export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]);
export type FieldInputTemplate = z.infer<typeof zFieldInputTemplate>;
export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate =>
zFieldInputTemplate.safeParse(val).success;
export const isFieldInputTemplate = buildTypeGuard(zFieldInputTemplate);
// #endregion
// #region StatefulFieldOutputTemplate & FieldOutputTemplate
const zStatefulFieldOutputTemplate = z.union([
zIntegerFieldOutputTemplate,
zIntegerFieldCollectionOutputTemplate,
zFloatFieldOutputTemplate,
zFloatFieldCollectionOutputTemplate,
zStringFieldOutputTemplate,
zStringFieldCollectionOutputTemplate,
zBooleanFieldOutputTemplate,
zEnumFieldOutputTemplate,
zImageFieldOutputTemplate,

View File

@@ -0,0 +1,133 @@
import type {
FloatFieldCollectionInputInstance,
FloatFieldCollectionInputTemplate,
ImageFieldCollectionInputTemplate,
ImageFieldCollectionValue,
IntegerFieldCollectionInputInstance,
IntegerFieldCollectionInputTemplate,
StringFieldCollectionInputTemplate,
StringFieldCollectionValue,
} from 'features/nodes/types/field';
import {
floatRangeStartStepCountGenerator,
integerRangeStartStepCountGenerator,
} from 'features/nodes/types/generators';
import { t } from 'i18next';
export const validateImageFieldCollectionValue = (
value: NonNullable<ImageFieldCollectionValue>,
template: ImageFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems } = template;
const count = value.length;
// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
return reasons;
};
export const validateStringFieldCollectionValue = (
value: NonNullable<StringFieldCollectionValue>,
template: StringFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minLength, maxLength } = template;
const count = value.length;
// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
for (const str of value) {
if (maxLength !== undefined && str.length > maxLength) {
reasons.push(t('parameters.invoke.collectionStringTooLong', { value, maxLength }));
}
if (minLength !== undefined && str.length < minLength) {
reasons.push(t('parameters.invoke.collectionStringTooShort', { value, minLength }));
}
}
return reasons;
};
export const resolveNumberFieldCollectionValue = (
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance
): number[] | undefined => {
if (field.generator?.type === 'float-range-generator-start-step-count') {
return floatRangeStartStepCountGenerator(field.generator);
} else if (field.generator?.type === 'integer-range-generator-start-step-count') {
return integerRangeStartStepCountGenerator(field.generator);
} else {
return field.value;
}
};
export const validateNumberFieldCollectionValue = (
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance,
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
): string[] => {
const reasons: string[] = [];
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
const value = resolveNumberFieldCollectionValue(field);
if (value === undefined) {
reasons.push(t('parameters.invoke.collectionEmpty'));
return reasons;
}
const count = value.length;
// Image collections may have min or max items to validate
if (minItems !== undefined && minItems > 0 && count === 0) {
reasons.push(t('parameters.invoke.collectionEmpty'));
}
if (minItems !== undefined && count < minItems) {
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
}
if (maxItems !== undefined && count > maxItems) {
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
}
for (const num of value) {
if (maximum !== undefined && num > maximum) {
reasons.push(t('parameters.invoke.collectionNumberGTMax', { value, maximum }));
}
if (minimum !== undefined && num < minimum) {
reasons.push(t('parameters.invoke.collectionNumberLTMin', { value, minimum }));
}
if (exclusiveMaximum !== undefined && num >= exclusiveMaximum) {
reasons.push(t('parameters.invoke.collectionNumberGTExclusiveMax', { value, exclusiveMaximum }));
}
if (exclusiveMinimum !== undefined && num <= exclusiveMinimum) {
reasons.push(t('parameters.invoke.collectionNumberLTExclusiveMin', { value, exclusiveMinimum }));
}
if (multipleOf !== undefined && num % multipleOf !== 0) {
reasons.push(t('parameters.invoke.collectionNumberNotMultipleOf', { value, multipleOf }));
}
}
return reasons;
};

View File

@@ -0,0 +1,29 @@
import { z } from 'zod';
export const zFloatRangeStartStepCountGenerator = z.object({
type: z.literal('float-range-generator-start-step-count').default('float-range-generator-start-step-count'),
start: z.number().default(0),
step: z.number().default(1),
count: z.number().int().default(10),
});
export type FloatRangeStartStepCountGenerator = z.infer<typeof zFloatRangeStartStepCountGenerator>;
export const floatRangeStartStepCountGenerator = (generator: FloatRangeStartStepCountGenerator): number[] => {
const { start, step, count } = generator;
return Array.from({ length: count }, (_, i) => start + i * step);
};
export const getDefaultFloatRangeStartStepCountGenerator = (): FloatRangeStartStepCountGenerator =>
zFloatRangeStartStepCountGenerator.parse({});
export const zIntegerRangeStartStepCountGenerator = z.object({
type: z.literal('integer-range-generator-start-step-count').default('integer-range-generator-start-step-count'),
start: z.number().int().default(0),
step: z.number().int().default(1),
count: z.number().int().default(10),
});
export type IntegerRangeStartStepCountGenerator = z.infer<typeof zIntegerRangeStartStepCountGenerator>;
export const integerRangeStartStepCountGenerator = (generator: IntegerRangeStartStepCountGenerator): number[] => {
const { start, step, count } = generator;
return Array.from({ length: count }, (_, i) => start + i * step);
};
export const getDefaultIntegerRangeStartStepCountGenerator = (): IntegerRangeStartStepCountGenerator =>
zIntegerRangeStartStepCountGenerator.parse({});

View File

@@ -91,3 +91,15 @@ const zInvocationNodeEdgeExtra = z.object({
type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
// #endregion
export const isBatchNode = (node: InvocationNode) => {
switch (node.data.type) {
case 'image_batch':
case 'string_batch':
case 'integer_batch':
case 'float_batch':
return true;
default:
return false;
}
};

View File

@@ -1,7 +1,9 @@
import { logger } from 'app/logging/logger';
import type { NodesState } from 'features/nodes/store/types';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { omit, reduce } from 'lodash-es';
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { negate, omit, reduce } from 'lodash-es';
import type { AnyInvocation, Graph } from 'services/api/types';
import { v4 as uuidv4 } from 'uuid';
@@ -14,7 +16,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const { nodes, edges } = nodesState;
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');
const filteredNodes = nodes.filter(isInvocationNode).filter(negate(isBatchNode));
// Reduce the node editor nodes into invocation graph nodes
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
@@ -25,7 +27,11 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
const transformedInputs = reduce(
inputs,
(inputsAccumulator, input, name) => {
inputsAccumulator[name] = input.value;
if (isFloatFieldCollectionInputInstance(input) || isIntegerFieldCollectionInputInstance(input)) {
inputsAccumulator[name] = resolveNumberFieldCollectionValue(input);
} else {
inputsAccumulator[name] = input.value;
}
return inputsAccumulator;
},

View File

@@ -11,11 +11,13 @@ import type {
EnumFieldInputTemplate,
FieldInputTemplate,
FieldType,
FloatFieldCollectionInputTemplate,
FloatFieldInputTemplate,
FluxMainModelFieldInputTemplate,
FluxVAEModelFieldInputTemplate,
ImageFieldCollectionInputTemplate,
ImageFieldInputTemplate,
IntegerFieldCollectionInputTemplate,
IntegerFieldInputTemplate,
IPAdapterModelFieldInputTemplate,
LoRAModelFieldInputTemplate,
@@ -28,12 +30,19 @@ import type {
SpandrelImageToImageModelFieldInputTemplate,
StatefulFieldType,
StatelessFieldInputTemplate,
StringFieldCollectionInputTemplate,
StringFieldInputTemplate,
T2IAdapterModelFieldInputTemplate,
T5EncoderModelFieldInputTemplate,
VAEModelFieldInputTemplate,
} from 'features/nodes/types/field';
import { isImageCollectionFieldType, isStatefulFieldType } from 'features/nodes/types/field';
import {
isFloatCollectionFieldType,
isImageCollectionFieldType,
isIntegerCollectionFieldType,
isStatefulFieldType,
isStringCollectionFieldType,
} from 'features/nodes/types/field';
import type { InvocationFieldSchema } from 'features/nodes/types/openapi';
import { isSchemaObject } from 'features/nodes/types/openapi';
import { t } from 'i18next';
@@ -77,6 +86,48 @@ const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<IntegerFieldInpu
return template;
};
const buildIntegerFieldCollectionInputTemplate: FieldInputTemplateBuilder<IntegerFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: IntegerFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -111,6 +162,48 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTem
return template;
};
const buildFloatFieldCollectionInputTemplate: FieldInputTemplateBuilder<FloatFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: FloatFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
if (schemaObject.multipleOf !== undefined) {
template.multipleOf = schemaObject.multipleOf;
}
if (schemaObject.maximum !== undefined) {
template.maximum = schemaObject.maximum;
}
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
}
if (schemaObject.minimum !== undefined) {
template.minimum = schemaObject.minimum;
}
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
}
return template;
};
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -133,6 +226,36 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
return template;
};
const buildStringFieldCollectionInputTemplate: FieldInputTemplateBuilder<StringFieldCollectionInputTemplate> = ({
schemaObject,
baseField,
fieldType,
}) => {
const template: StringFieldCollectionInputTemplate = {
...baseField,
type: fieldType,
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
};
if (schemaObject.minLength !== undefined) {
template.minLength = schemaObject.minLength;
}
if (schemaObject.maxLength !== undefined) {
template.maxLength = schemaObject.maxLength;
}
if (schemaObject.minItems !== undefined) {
template.minItems = schemaObject.minItems;
}
if (schemaObject.maxItems !== undefined) {
template.maxItems = schemaObject.maxItems;
}
return template;
};
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
schemaObject,
baseField,
@@ -569,12 +692,29 @@ export const buildFieldInputTemplate = (
if (isStatefulFieldType(fieldType)) {
if (isImageCollectionFieldType(fieldType)) {
fieldType;
return buildImageFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else if (isStringCollectionFieldType(fieldType)) {
return buildStringFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else if (isIntegerCollectionFieldType(fieldType)) {
return buildIntegerFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else if (isFloatCollectionFieldType(fieldType)) {
return buildFloatFieldCollectionInputTemplate({
schemaObject: fieldSchema,
baseField,
fieldType,
});
} else {
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
const template = builder({

View File

@@ -20,7 +20,7 @@ import { z } from 'zod';
* @param schema The zod schema to create a type guard from.
* @returns A type guard function for the schema.
*/
const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
export const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
};

View File

@@ -206,8 +206,16 @@ const QueueCountPredictionWorkflowsTab = memo(() => {
const iterationsCount = useAppSelector(selectIterations);
const text = useMemo(() => {
const generationCount = Math.min(batchSize * iterationsCount, 10000);
const iterations = t('queue.iterations', { count: iterationsCount });
if (batchSize === 'NO_BATCHES') {
const generationCount = Math.min(10000, iterationsCount);
const generations = t('queue.generations', { count: generationCount });
return `${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
}
if (batchSize === 'INVALID') {
return t('parameters.invoke.invalidBatchConfiguration');
}
const generationCount = Math.min(batchSize * iterationsCount, 10000);
const generations = t('queue.generations', { count: generationCount });
return `${batchSize} ${t('queue.batchSize')} \u00d7 ${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
}, [batchSize, iterationsCount, t]);

View File

@@ -1,4 +1,5 @@
import { createSelector } from '@reduxjs/toolkit';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import type { AppConfig } from 'app/types/invokeai';
import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
@@ -18,14 +19,31 @@ import { selectNodesSlice } from 'features/nodes/store/selectors';
import type { NodesState, Templates } from 'features/nodes/store/types';
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
import { isInvocationNode } from 'features/nodes/types/invocation';
import {
isFloatFieldCollectionInputInstance,
isFloatFieldCollectionInputTemplate,
isImageFieldCollectionInputInstance,
isImageFieldCollectionInputTemplate,
isIntegerFieldCollectionInputInstance,
isIntegerFieldCollectionInputTemplate,
isStringFieldCollectionInputInstance,
isStringFieldCollectionInputTemplate,
} from 'features/nodes/types/field';
import {
resolveNumberFieldCollectionValue,
validateImageFieldCollectionValue,
validateNumberFieldCollectionValue,
validateStringFieldCollectionValue,
} from 'features/nodes/types/fieldValidators';
import type { InvocationNode } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
import { selectConfigSlice } from 'features/system/store/configSlice';
import i18n from 'i18next';
import { forEach, upperFirst } from 'lodash-es';
import { forEach, groupBy, negate, upperFirst } from 'lodash-es';
import { getConnectedEdges } from 'reactflow';
import { assert } from 'tsafe';
/**
* This file contains selectors and utilities for determining the app is ready to enqueue generations. The handling
@@ -61,11 +79,56 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
}
if (workflowSettings.shouldValidateGraph) {
if (!nodes.nodes.length) {
const invocationNodes = nodes.nodes.filter(isInvocationNode);
const batchNodes = invocationNodes.filter(isBatchNode);
const nonBatchNodes = invocationNodes.filter(negate(isBatchNode));
if (!nonBatchNodes.length) {
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
}
nodes.nodes.forEach((node) => {
for (const node of batchNodes) {
if (nodes.edges.find((e) => e.source === node.id) === undefined) {
reasons.push({ content: i18n.t('parameters.invoke.batchNodeNotConnected', { label: node.data.label }) });
}
}
if (batchNodes.length > 1) {
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
if (batchGroupId === 'None') {
// Ungrouped batch nodes may have differing collection sizes
continue;
}
// But grouped batch nodes must have the same collection size
const collectionSizes: number[] = [];
for (const node of batchNodes) {
if (node.data.type === 'image_batch') {
assert(isImageFieldCollectionInputInstance(node.data.inputs.images));
collectionSizes.push(node.data.inputs.images.value?.length ?? 0);
} else if (node.data.type === 'string_batch') {
assert(isStringFieldCollectionInputInstance(node.data.inputs.strings));
collectionSizes.push(node.data.inputs.strings.value?.length ?? 0);
} else if (node.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(node.data.inputs.floats));
collectionSizes.push(node.data.inputs.floats.value?.length ?? 0);
} else if (node.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(node.data.inputs.integers));
collectionSizes.push(node.data.inputs.integers.value?.length ?? 0);
}
}
if (collectionSizes.some((count) => count !== collectionSizes[0])) {
reasons.push({
content: i18n.t('parameters.invoke.batchNodeCollectionSizeMismatch', { batchGroupId }),
});
}
}
}
nonBatchNodes.forEach((node) => {
if (!isInvocationNode(node)) {
return;
}
@@ -91,45 +154,38 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
return;
}
const baseTKeyOptions = {
nodeLabel: node.data.label || nodeTemplate.title,
fieldLabel: field.label || fieldTemplate.title,
};
const prefix = `${node.data.label || nodeTemplate.title} -> ${field.label || fieldTemplate.title}`;
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
reasons.push({ content: i18n.t('parameters.invoke.missingInputForField', baseTKeyOptions) });
return;
reasons.push({ prefix, content: i18n.t('parameters.invoke.missingInputForField') });
} else if (
field.value &&
isImageFieldCollectionInputInstance(field) &&
isImageFieldCollectionInputTemplate(fieldTemplate)
) {
// Image collections may have min or max items to validate
// TODO(psyche): generalize this to other collection types
if (fieldTemplate.minItems !== undefined && fieldTemplate.minItems > 0 && field.value.length === 0) {
reasons.push({ content: i18n.t('parameters.invoke.collectionEmpty', baseTKeyOptions) });
return;
}
if (fieldTemplate.minItems !== undefined && field.value.length < fieldTemplate.minItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooFewItems', {
...baseTKeyOptions,
size: field.value.length,
minItems: fieldTemplate.minItems,
}),
});
return;
}
if (fieldTemplate.maxItems !== undefined && field.value.length > fieldTemplate.maxItems) {
reasons.push({
content: i18n.t('parameters.invoke.collectionTooManyItems', {
...baseTKeyOptions,
size: field.value.length,
maxItems: fieldTemplate.maxItems,
}),
});
return;
}
const errors = validateImageFieldCollectionValue(field.value, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isStringFieldCollectionInputInstance(field) &&
isStringFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateStringFieldCollectionValue(field.value, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isIntegerFieldCollectionInputInstance(field) &&
isIntegerFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
} else if (
field.value &&
isFloatFieldCollectionInputInstance(field) &&
isFloatFieldCollectionInputTemplate(fieldTemplate)
) {
const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
reasons.push(...errors.map((error) => ({ prefix, content: error })));
}
});
});
@@ -491,17 +547,97 @@ export const selectPromptsCount = createSelector(
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
);
export const selectWorkflowsBatchSize = createSelector(selectNodesSlice, ({ nodes }) =>
// The batch size is the product of all batch nodes' collection sizes
nodes.filter(isInvocationNode).reduce((batchSize, node) => {
if (!isImageFieldCollectionInputInstance(node.data.inputs.images)) {
return batchSize;
}
// If the batch size is not set, default to 1
batchSize = batchSize || 1;
// Multiply the batch size by the number of images in the batch
batchSize = batchSize * (node.data.inputs.images.value?.length ?? 0);
const getBatchCollectionSize = (batchNode: InvocationNode) => {
if (batchNode.data.type === 'image_batch') {
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
return batchNode.data.inputs.images.value?.length ?? 0;
} else if (batchNode.data.type === 'string_batch') {
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
return batchNode.data.inputs.strings.value?.length ?? 0;
} else if (batchNode.data.type === 'float_batch') {
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
return resolveNumberFieldCollectionValue(batchNode.data.inputs.floats)?.length ?? 0;
} else if (batchNode.data.type === 'integer_batch') {
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
return resolveNumberFieldCollectionValue(batchNode.data.inputs.integers)?.length ?? 0;
}
return 0;
};
return batchSize;
}, 0)
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
createMemoizedSelector(selectNodesSlice, ({ nodes }) => {
return nodes
.filter(isInvocationNode)
.filter(isBatchNode)
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
.map(getBatchCollectionSize);
});
const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');
const selectGroup1BatchSizes = buildSelectGroupBatchSizes('Group 1');
const selectGroup2BatchSizes = buildSelectGroupBatchSizes('Group 2');
const selectGroup3BatchSizes = buildSelectGroupBatchSizes('Group 3');
const selectGroup4BatchSizes = buildSelectGroupBatchSizes('Group 4');
const selectGroup5BatchSizes = buildSelectGroupBatchSizes('Group 5');
export const selectWorkflowsBatchSize = createSelector(
selectUngroupedBatchSizes,
selectGroup1BatchSizes,
selectGroup2BatchSizes,
selectGroup3BatchSizes,
selectGroup4BatchSizes,
selectGroup5BatchSizes,
(
ungroupedBatchSizes,
group1BatchSizes,
group2BatchSizes,
group3BatchSizes,
group4BatchSizes,
group5BatchSizes
): number | 'INVALID' | 'NO_BATCHES' => {
// All batch nodes _must_ have a populated collection
const allBatchSizes = [
...ungroupedBatchSizes,
...group1BatchSizes,
...group2BatchSizes,
...group3BatchSizes,
...group4BatchSizes,
...group5BatchSizes,
];
// There are no batch nodes
if (allBatchSizes.length === 0) {
return 'NO_BATCHES';
}
// All batch nodes must have a populated collection
if (allBatchSizes.some((size) => size === 0)) {
return 'INVALID';
}
for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) {
// Ignore groups with no batch nodes
if (group.length === 0) {
continue;
}
// Grouped batch nodes must have the same collection size
if (group.some((size) => size !== group[0])) {
return 'INVALID';
}
}
// Total batch size = product of all ungrouped batches and each grouped batch
const totalBatchSize = [
...ungroupedBatchSizes,
// In case of no batch nodes in a group, fall back to 1 for the product calculation
group1BatchSizes[0] ?? 1,
group2BatchSizes[0] ?? 1,
group3BatchSizes[0] ?? 1,
group4BatchSizes[0] ?? 1,
group5BatchSizes[0] ?? 1,
].reduce((acc, size) => acc * size, 1);
return totalBatchSize;
}
);

File diff suppressed because one or more lines are too long

View File

@@ -189,6 +189,26 @@ def test_cannot_create_bad_batch_items_type(batch_graph):
)
def test_number_type_interop(batch_graph):
# integers and floats can be mixed, should not throw an error
Batch(
graph=batch_graph,
data=[
[
BatchDatum(node_path="1", field_name="prompt", items=[1, 1.5]),
]
],
)
Batch(
graph=batch_graph,
data=[
[
BatchDatum(node_path="1", field_name="prompt", items=[1.5, 1]),
]
],
)
def test_cannot_create_bad_batch_unique_ids(batch_graph):
with pytest.raises(ValidationError, match="Each batch data must have unique node_id and field_name"):
Batch(