mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-15 06:18:03 -05:00
Compare commits
65 Commits
v5.9.1
...
psyche/ref
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6df3e9f960 | ||
|
|
42d59b961e | ||
|
|
e8eac3d259 | ||
|
|
74a4197398 | ||
|
|
ba1701d734 | ||
|
|
10a77d22ef | ||
|
|
54dbc16cc9 | ||
|
|
a95cc979a9 | ||
|
|
b23abba8a6 | ||
|
|
d1ded55d8d | ||
|
|
d81cbd0a14 | ||
|
|
fa21f0887d | ||
|
|
4ecb4e8929 | ||
|
|
ede13f7882 | ||
|
|
77472a2f0c | ||
|
|
029c1fb8d9 | ||
|
|
100b151f84 | ||
|
|
a7326e3ad4 | ||
|
|
775bb276b2 | ||
|
|
977c2668e8 | ||
|
|
515ff485fc | ||
|
|
87cbc8ad45 | ||
|
|
2c250c29e8 | ||
|
|
9db69782e9 | ||
|
|
2499cf0c52 | ||
|
|
febc9615fc | ||
|
|
b6aae16471 | ||
|
|
f27250a4a7 | ||
|
|
0b95319dfa | ||
|
|
388580efb3 | ||
|
|
9ac30bd2a5 | ||
|
|
c3d2eb5426 | ||
|
|
789eb1fff5 | ||
|
|
fb5af7a4b7 | ||
|
|
eea29863a0 | ||
|
|
fab0af4d77 | ||
|
|
420c1d2874 | ||
|
|
d980a87e25 | ||
|
|
d02a8a9b62 | ||
|
|
084228c162 | ||
|
|
366ac86cbe | ||
|
|
55423ad1d6 | ||
|
|
11f17e3ea0 | ||
|
|
2d145871d9 | ||
|
|
aace8366d6 | ||
|
|
77555615bc | ||
|
|
4d9b35e8bd | ||
|
|
3cecc25d6c | ||
|
|
5a610dd00c | ||
|
|
2d8443cc21 | ||
|
|
1ed70bb21e | ||
|
|
e7a61c86f1 | ||
|
|
7d1f38560b | ||
|
|
2fcca151e7 | ||
|
|
c83caed552 | ||
|
|
a1687fafdd | ||
|
|
41a36f2701 | ||
|
|
8807248a1c | ||
|
|
14a605c1e1 | ||
|
|
edc39581ba | ||
|
|
de21cf1383 | ||
|
|
035508e2ee | ||
|
|
1b739b4f86 | ||
|
|
32f65937af | ||
|
|
6cb87a86c8 |
118
invokeai/app/invocations/batch.py
Normal file
118
invokeai/app/invocations/batch.py
Normal file
@@ -0,0 +1,118 @@
|
||||
from typing import Literal
|
||||
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
Classification,
|
||||
invocation,
|
||||
)
|
||||
from invokeai.app.invocations.fields import (
|
||||
ImageField,
|
||||
Input,
|
||||
InputField,
|
||||
)
|
||||
from invokeai.app.invocations.primitives import FloatOutput, ImageOutput, IntegerOutput, StringOutput
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
|
||||
BATCH_GROUP_IDS = Literal[
|
||||
"None",
|
||||
"Group 1",
|
||||
"Group 2",
|
||||
"Group 3",
|
||||
"Group 4",
|
||||
"Group 5",
|
||||
]
|
||||
|
||||
|
||||
class NotExecutableNodeError(Exception):
|
||||
def __init__(self, message: str = "This class should never be executed or instantiated directly."):
|
||||
super().__init__(message)
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseBatchInvocation(BaseInvocation):
|
||||
batch_group_id: BATCH_GROUP_IDS = InputField(
|
||||
default="None",
|
||||
description="The ID of this batch node's group. If provided, all batch nodes in with the same ID will be 'zipped' before execution, and all nodes' collections must be of the same size.",
|
||||
input=Input.Direct,
|
||||
title="Batch Group",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(
|
||||
default=[], min_length=1, description="The images to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"string_batch",
|
||||
title="String Batch",
|
||||
tags=["primitives", "string", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class StringBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each string in the batch."""
|
||||
|
||||
strings: list[str] = InputField(
|
||||
default=[], min_length=1, description="The strings to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> StringOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"integer_batch",
|
||||
title="Integer Batch",
|
||||
tags=["primitives", "integer", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class IntegerBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each integer in the batch."""
|
||||
|
||||
integers: list[int] = InputField(
|
||||
default=[], min_length=1, description="The integers to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> IntegerOutput:
|
||||
raise NotExecutableNodeError()
|
||||
|
||||
|
||||
@invocation(
|
||||
"float_batch",
|
||||
title="Float Batch",
|
||||
tags=["primitives", "float", "number", "batch", "special"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class FloatBatchInvocation(BaseBatchInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each float in the batch."""
|
||||
|
||||
floats: list[float] = InputField(
|
||||
default=[], min_length=1, description="The floats to batch over", input=Input.Direct
|
||||
)
|
||||
|
||||
def invoke(self, context: InvocationContext) -> FloatOutput:
|
||||
raise NotExecutableNodeError()
|
||||
@@ -7,7 +7,6 @@ import torch
|
||||
from invokeai.app.invocations.baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
Classification,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
@@ -539,23 +538,3 @@ class BoundingBoxInvocation(BaseInvocation):
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@invocation(
|
||||
"image_batch",
|
||||
title="Image Batch",
|
||||
tags=["primitives", "image", "batch", "internal"],
|
||||
category="primitives",
|
||||
version="1.0.0",
|
||||
classification=Classification.Special,
|
||||
)
|
||||
class ImageBatchInvocation(BaseInvocation):
|
||||
"""Create a batched generation, where the workflow is executed once for each image in the batch."""
|
||||
|
||||
images: list[ImageField] = InputField(min_length=1, description="The images to batch over", input=Input.Direct)
|
||||
|
||||
def __init__(self):
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||
raise NotImplementedError("This class should never be executed or instantiated directly.")
|
||||
|
||||
@@ -108,8 +108,16 @@ class Batch(BaseModel):
|
||||
return v
|
||||
for batch_data_list in v:
|
||||
for datum in batch_data_list:
|
||||
if not datum.items:
|
||||
continue
|
||||
|
||||
# Special handling for numbers - they can be mixed
|
||||
# TODO(psyche): Update BatchDatum to have a `type` field to specify the type of the items, then we can have strict float and int fields
|
||||
if all(isinstance(item, (int, float)) for item in datum.items):
|
||||
continue
|
||||
|
||||
# Get the type of the first item in the list
|
||||
first_item_type = type(datum.items[0]) if datum.items else None
|
||||
first_item_type = type(datum.items[0])
|
||||
for item in datum.items:
|
||||
if type(item) is not first_item_type:
|
||||
raise BatchItemsTypeError("All items in a batch must have the same type")
|
||||
|
||||
@@ -177,7 +177,11 @@
|
||||
"none": "None",
|
||||
"new": "New",
|
||||
"generating": "Generating",
|
||||
"warnings": "Warnings"
|
||||
"warnings": "Warnings",
|
||||
"start": "Start",
|
||||
"count": "Count",
|
||||
"step": "Step",
|
||||
"values": "Values"
|
||||
},
|
||||
"hrf": {
|
||||
"hrf": "High Resolution Fix",
|
||||
@@ -850,7 +854,14 @@
|
||||
"defaultVAE": "Default VAE"
|
||||
},
|
||||
"nodes": {
|
||||
"noBatchGroup": "no group",
|
||||
"generator": "Generator",
|
||||
"generatedValues": "Generated Values",
|
||||
"commitValues": "Commit Values",
|
||||
"addValue": "Add Value",
|
||||
"addNode": "Add Node",
|
||||
"lockLinearView": "Lock Linear View",
|
||||
"unlockLinearView": "Unlock Linear View",
|
||||
"addNodeToolTip": "Add Node (Shift+A, Space)",
|
||||
"addLinearView": "Add to Linear View",
|
||||
"animatedEdges": "Animated Edges",
|
||||
@@ -1024,11 +1035,21 @@
|
||||
"addingImagesTo": "Adding images to",
|
||||
"invoke": "Invoke",
|
||||
"missingFieldTemplate": "Missing field template",
|
||||
"missingInputForField": "{{nodeLabel}} -> {{fieldLabel}}: missing input",
|
||||
"missingInputForField": "missing input",
|
||||
"missingNodeTemplate": "Missing node template",
|
||||
"collectionEmpty": "{{nodeLabel}} -> {{fieldLabel}} empty collection",
|
||||
"collectionTooFewItems": "{{nodeLabel}} -> {{fieldLabel}}: too few items, minimum {{minItems}}",
|
||||
"collectionTooManyItems": "{{nodeLabel}} -> {{fieldLabel}}: too many items, maximum {{maxItems}}",
|
||||
"collectionEmpty": "empty collection",
|
||||
"invalidBatchConfiguration": "Invalid batch configuration",
|
||||
"batchNodeNotConnected": "Batch node not connected: {{label}}",
|
||||
"collectionTooFewItems": "too few items, minimum {{minItems}}",
|
||||
"collectionTooManyItems": "too many items, maximum {{maxItems}}",
|
||||
"collectionStringTooLong": "too long, max {{maxLength}}",
|
||||
"collectionStringTooShort": "too short, min {{minLength}}",
|
||||
"collectionNumberGTMax": "{{value}} > {{maximum}} (inc max)",
|
||||
"collectionNumberLTMin": "{{value}} < {{minimum}} (inc min)",
|
||||
"collectionNumberGTExclusiveMax": "{{value}} >= {{exclusiveMaximum}} (exc max)",
|
||||
"collectionNumberLTExclusiveMin": "{{value}} <= {{exclusiveMinimum}} (exc min)",
|
||||
"collectionNumberNotMultipleOf": "{{value}} not multiple of {{multipleOf}}",
|
||||
"batchNodeCollectionSizeMismatch": "Collection size mismatch on Batch {{batchGroupId}}",
|
||||
"noModelSelected": "No model selected",
|
||||
"noT5EncoderModelSelected": "No T5 Encoder model selected for FLUX generation",
|
||||
"noFLUXVAEModelSelected": "No VAE model selected for FLUX generation",
|
||||
|
||||
@@ -2,10 +2,19 @@ import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequested } from 'app/store/actions';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputInstance,
|
||||
} from 'features/nodes/types/field';
|
||||
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
|
||||
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
|
||||
import { groupBy } from 'lodash-es';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import type { Batch, BatchConfig } from 'services/api/types';
|
||||
|
||||
@@ -33,28 +42,140 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
|
||||
|
||||
const data: Batch['data'] = [];
|
||||
|
||||
// Skip edges from batch nodes - these should not be in the graph, they exist only in the UI
|
||||
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
|
||||
for (const node of imageBatchNodes) {
|
||||
const images = node.data.inputs['images'];
|
||||
if (!isImageFieldCollectionInputInstance(images)) {
|
||||
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
|
||||
break;
|
||||
}
|
||||
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
|
||||
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edgesFromImageBatch) {
|
||||
const batchNodes = nodes.nodes.filter(isInvocationNode).filter(isBatchNode);
|
||||
|
||||
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
|
||||
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
|
||||
|
||||
const addProductBatchDataCollectionItem = (
|
||||
edges: InvocationNodeEdge[],
|
||||
items?: ImageField[] | string[] | number[]
|
||||
) => {
|
||||
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edges) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
batchDataCollectionItem.push({
|
||||
productBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: images.value,
|
||||
items,
|
||||
});
|
||||
}
|
||||
if (batchDataCollectionItem.length > 0) {
|
||||
data.push(batchDataCollectionItem);
|
||||
if (productBatchDataCollectionItems.length > 0) {
|
||||
data.push(productBatchDataCollectionItems);
|
||||
}
|
||||
};
|
||||
|
||||
// Then, we will create a batch data collection item for each group
|
||||
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
|
||||
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
const addZippedBatchDataCollectionItem = (
|
||||
edges: InvocationNodeEdge[],
|
||||
items?: ImageField[] | string[] | number[]
|
||||
) => {
|
||||
for (const edge of edges) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
zippedBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Grab image batch nodes for special handling
|
||||
const imageBatchNodes = batchNodes.filter((node) => node.data.type === 'image_batch');
|
||||
|
||||
for (const node of imageBatchNodes) {
|
||||
// Satisfy TS
|
||||
const images = node.data.inputs['images'];
|
||||
if (!isImageFieldCollectionInputInstance(images)) {
|
||||
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromImageBatch, images.value);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromImageBatch, images.value);
|
||||
}
|
||||
}
|
||||
|
||||
// Grab string batch nodes for special handling
|
||||
const stringBatchNodes = batchNodes.filter((node) => node.data.type === 'string_batch');
|
||||
for (const node of stringBatchNodes) {
|
||||
// Satisfy TS
|
||||
const strings = node.data.inputs['strings'];
|
||||
if (!isStringFieldCollectionInputInstance(strings)) {
|
||||
log.warn({ nodeId: node.id }, 'String batch strings field is not a string collection');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, strings.value);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, strings.value);
|
||||
}
|
||||
}
|
||||
|
||||
// Grab integer batch nodes for special handling
|
||||
const integerBatchNodes = batchNodes.filter((node) => node.data.type === 'integer_batch');
|
||||
for (const node of integerBatchNodes) {
|
||||
// Satisfy TS
|
||||
const integers = node.data.inputs['integers'];
|
||||
if (!isIntegerFieldCollectionInputInstance(integers)) {
|
||||
log.warn({ nodeId: node.id }, 'Integer batch integers field is not an integer collection');
|
||||
break;
|
||||
}
|
||||
if (!integers.value) {
|
||||
log.warn({ nodeId: node.id }, 'Integer batch integers field is empty');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
const resolvedValue = resolveNumberFieldCollectionValue(integers);
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
}
|
||||
}
|
||||
|
||||
// Grab float batch nodes for special handling
|
||||
const floatBatchNodes = batchNodes.filter((node) => node.data.type === 'float_batch');
|
||||
for (const node of floatBatchNodes) {
|
||||
// Satisfy TS
|
||||
const floats = node.data.inputs['floats'];
|
||||
if (!isFloatFieldCollectionInputInstance(floats)) {
|
||||
log.warn({ nodeId: node.id }, 'Float batch floats field is not a float collection');
|
||||
break;
|
||||
}
|
||||
if (!floats.value) {
|
||||
log.warn({ nodeId: node.id }, 'Float batch floats field is empty');
|
||||
break;
|
||||
}
|
||||
|
||||
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
|
||||
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
|
||||
const resolvedValue = resolveNumberFieldCollectionValue(floats);
|
||||
if (batchGroupId !== 'None') {
|
||||
addZippedBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
} else {
|
||||
addProductBatchDataCollectionItem(edgesFromStringBatch, resolvedValue);
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, if this batch data collection item has any items, add it to the data array
|
||||
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
|
||||
data.push(zippedBatchDataCollectionItems);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,8 +166,10 @@ export const createStore = (uniqueStoreKey?: string, persist = true) =>
|
||||
reducer: rememberedRootReducer,
|
||||
middleware: (getDefaultMiddleware) =>
|
||||
getDefaultMiddleware({
|
||||
serializableCheck: import.meta.env.MODE === 'development',
|
||||
immutableCheck: import.meta.env.MODE === 'development',
|
||||
serializableCheck: false,
|
||||
immutableCheck: false,
|
||||
// serializableCheck: import.meta.env.MODE === 'development',
|
||||
// immutableCheck: import.meta.env.MODE === 'development',
|
||||
})
|
||||
.concat(api.middleware)
|
||||
.concat(dynamicMiddlewares)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { getPrefixedId } from 'features/controlLayers/konva/util';
|
||||
import type {
|
||||
@@ -9,7 +10,6 @@ import { selectComparisonImages } from 'features/gallery/components/ImageViewer/
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import {
|
||||
addImagesToBoard,
|
||||
addImagesToNodeImageFieldCollectionAction,
|
||||
createNewCanvasEntityFromImage,
|
||||
removeImagesFromBoard,
|
||||
replaceCanvasEntityObjectsWithImage,
|
||||
@@ -19,10 +19,14 @@ import {
|
||||
setRegionalGuidanceReferenceImage,
|
||||
setUpscaleInitialImage,
|
||||
} from 'features/imageActions/actions';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { JsonObject } from 'type-fest';
|
||||
|
||||
const log = logger('dnd');
|
||||
|
||||
type RecordUnknown = Record<string | symbol, unknown>;
|
||||
|
||||
type DndData<
|
||||
@@ -268,15 +272,27 @@ export const addImagesToNodeImageFieldCollectionDndTarget: DndTarget<
|
||||
}
|
||||
|
||||
const { fieldIdentifier } = targetData.payload;
|
||||
const imageDTOs: ImageDTO[] = [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
imageDTOs.push(sourceData.payload.imageDTO);
|
||||
} else {
|
||||
imageDTOs.push(...sourceData.payload.imageDTOs);
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
addImagesToNodeImageFieldCollectionAction({ fieldIdentifier, imageDTOs, dispatch, getState });
|
||||
const newValue = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
|
||||
if (singleImageDndSource.typeGuard(sourceData)) {
|
||||
newValue.push({ image_name: sourceData.payload.imageDTO.image_name });
|
||||
} else {
|
||||
newValue.push(...sourceData.payload.imageDTOs.map(({ image_name }) => ({ image_name })));
|
||||
}
|
||||
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: newValue }));
|
||||
},
|
||||
};
|
||||
//#endregion
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
import { deepClone } from 'common/util/deepClone';
|
||||
import { selectDefaultIPAdapter } from 'features/controlLayers/hooks/addLayerHooks';
|
||||
@@ -31,19 +30,15 @@ import { imageDTOToImageObject, imageDTOToImageWithDims, initialControlNet } fro
|
||||
import { calculateNewSize } from 'features/controlLayers/util/getScaledBoundingBoxDimensions';
|
||||
import { imageToCompareChanged, selectionChanged } from 'features/gallery/store/gallerySlice';
|
||||
import type { BoardId } from 'features/gallery/store/types';
|
||||
import { fieldImageCollectionValueChanged, fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { type FieldIdentifier, isImageFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { fieldImageValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { FieldIdentifier } from 'features/nodes/types/field';
|
||||
import { upscaleInitialImageChanged } from 'features/parameters/store/upscaleSlice';
|
||||
import { getOptimalDimension } from 'features/parameters/util/optimalDimension';
|
||||
import { uniqBy } from 'lodash-es';
|
||||
import { imagesApi } from 'services/api/endpoints/images';
|
||||
import type { ImageDTO } from 'services/api/types';
|
||||
import type { Equals } from 'tsafe';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
const log = logger('system');
|
||||
|
||||
export const setGlobalReferenceImage = (arg: {
|
||||
imageDTO: ImageDTO;
|
||||
entityIdentifier: CanvasEntityIdentifier<'reference_image'>;
|
||||
@@ -77,54 +72,6 @@ export const setNodeImageFieldImage = (arg: {
|
||||
dispatch(fieldImageValueChanged({ ...fieldIdentifier, value: imageDTO }));
|
||||
};
|
||||
|
||||
export const addImagesToNodeImageFieldCollectionAction = (arg: {
|
||||
imageDTOs: ImageDTO[];
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { imageDTOs, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to add images to a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
images.push(...imageDTOs.map(({ image_name }) => ({ image_name })));
|
||||
const uniqueImages = uniqBy(images, 'image_name');
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
};
|
||||
|
||||
export const removeImageFromNodeImageFieldCollectionAction = (arg: {
|
||||
imageName: string;
|
||||
fieldIdentifier: FieldIdentifier;
|
||||
dispatch: AppDispatch;
|
||||
getState: () => RootState;
|
||||
}) => {
|
||||
const { imageName, fieldIdentifier, dispatch, getState } = arg;
|
||||
const fieldInputInstance = selectFieldInputInstance(
|
||||
selectNodesSlice(getState()),
|
||||
fieldIdentifier.nodeId,
|
||||
fieldIdentifier.fieldName
|
||||
);
|
||||
|
||||
if (!isImageFieldCollectionInputInstance(fieldInputInstance)) {
|
||||
log.warn({ fieldIdentifier }, 'Attempted to remove image from a non-image field collection');
|
||||
return;
|
||||
}
|
||||
|
||||
const images = fieldInputInstance.value ? [...fieldInputInstance.value] : [];
|
||||
const imagesWithoutTheImageToRemove = images.filter((image) => image.image_name !== imageName);
|
||||
const uniqueImages = uniqBy(imagesWithoutTheImageToRemove, 'image_name');
|
||||
dispatch(fieldImageCollectionValueChanged({ ...fieldIdentifier, value: uniqueImages }));
|
||||
};
|
||||
|
||||
export const setComparisonImage = (arg: { imageDTO: ImageDTO; dispatch: AppDispatch }) => {
|
||||
const { imageDTO, dispatch } = arg;
|
||||
dispatch(imageToCompareChanged(imageDTO));
|
||||
|
||||
@@ -43,7 +43,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
{fieldNames.connectionFields.map((fieldName, i) => (
|
||||
<GridItem gridColumnStart={1} gridRowStart={i + 1} key={`${nodeId}.${fieldName}.input-field`}>
|
||||
<InvocationInputFieldCheck nodeId={nodeId} fieldName={fieldName}>
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
|
||||
</InvocationInputFieldCheck>
|
||||
</GridItem>
|
||||
))}
|
||||
@@ -59,7 +59,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
>
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
|
||||
</InvocationInputFieldCheck>
|
||||
))}
|
||||
{fieldNames.missingFields.map((fieldName) => (
|
||||
@@ -68,7 +68,7 @@ const InvocationNode = ({ nodeId, isOpen, label, type, selected }: Props) => {
|
||||
nodeId={nodeId}
|
||||
fieldName={fieldName}
|
||||
>
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputField nodeId={nodeId} fieldName={fieldName} isLinearView={false} />
|
||||
</InvocationInputFieldCheck>
|
||||
))}
|
||||
</Flex>
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import { IconButton } from '@invoke-ai/ui-library';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import {
|
||||
selectWorkflowSlice,
|
||||
workflowExposedFieldAdded,
|
||||
@@ -19,7 +19,7 @@ type Props = {
|
||||
const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const { t } = useTranslation();
|
||||
const value = useFieldValue(nodeId, fieldName);
|
||||
const field = useFieldInputInstance(nodeId, fieldName);
|
||||
const selectIsExposed = useMemo(
|
||||
() =>
|
||||
createSelector(selectWorkflowSlice, (workflow) => {
|
||||
@@ -31,8 +31,11 @@ const FieldLinearViewToggle = ({ nodeId, fieldName }: Props) => {
|
||||
const isExposed = useAppSelector(selectIsExposed);
|
||||
|
||||
const handleExposeField = useCallback(() => {
|
||||
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, value }));
|
||||
}, [dispatch, fieldName, nodeId, value]);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
dispatch(workflowExposedFieldAdded({ nodeId, fieldName, field }));
|
||||
}, [dispatch, field, fieldName, nodeId]);
|
||||
|
||||
const handleUnexposeField = useCallback(() => {
|
||||
dispatch(workflowExposedFieldRemoved({ nodeId, fieldName }));
|
||||
|
||||
@@ -14,9 +14,10 @@ import { InputFieldWrapper } from './InputFieldWrapper';
|
||||
interface Props {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
isLinearView: boolean;
|
||||
}
|
||||
|
||||
const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
const InputField = ({ nodeId, fieldName, isLinearView }: Props) => {
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const isInvalid = useFieldIsInvalid(nodeId, fieldName);
|
||||
@@ -69,12 +70,12 @@ const InputField = ({ nodeId, fieldName }: Props) => {
|
||||
px={2}
|
||||
>
|
||||
<Flex flexDir="column" w="full" gap={1} onMouseEnter={onMouseEnter} onMouseLeave={onMouseLeave}>
|
||||
<Flex gap={1}>
|
||||
<Flex gap={1} alignItems="center">
|
||||
<EditableFieldTitle nodeId={nodeId} fieldName={fieldName} kind="inputs" isInvalid={isInvalid} withTooltip />
|
||||
{isHovered && <FieldResetToDefaultValueButton nodeId={nodeId} fieldName={fieldName} />}
|
||||
{isHovered && <FieldLinearViewToggle nodeId={nodeId} fieldName={fieldName} />}
|
||||
</Flex>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={isLinearView} />
|
||||
</Flex>
|
||||
</FormControl>
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { ImageFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ImageFieldCollectionInputComponent';
|
||||
import ModelIdentifierFieldInputComponent from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/ModelIdentifierFieldInputComponent';
|
||||
import { NumberFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/NumberFieldCollectionInputComponent';
|
||||
import { StringFieldCollectionInputComponent } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/StringFieldCollectionInputComponent';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import {
|
||||
@@ -21,6 +23,8 @@ import {
|
||||
isControlNetModelFieldInputTemplate,
|
||||
isEnumFieldInputInstance,
|
||||
isEnumFieldInputTemplate,
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isFloatFieldInputInstance,
|
||||
isFloatFieldInputTemplate,
|
||||
isFluxMainModelFieldInputInstance,
|
||||
@@ -31,6 +35,8 @@ import {
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isImageFieldInputInstance,
|
||||
isImageFieldInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isIntegerFieldInputInstance,
|
||||
isIntegerFieldInputTemplate,
|
||||
isIPAdapterModelFieldInputInstance,
|
||||
@@ -51,6 +57,8 @@ import {
|
||||
isSDXLRefinerModelFieldInputTemplate,
|
||||
isSpandrelImageToImageModelFieldInputInstance,
|
||||
isSpandrelImageToImageModelFieldInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
isStringFieldInputInstance,
|
||||
isStringFieldInputTemplate,
|
||||
isT2IAdapterModelFieldInputInstance,
|
||||
@@ -91,96 +99,285 @@ import VAEModelFieldInputComponent from './inputs/VAEModelFieldInputComponent';
|
||||
type InputFieldProps = {
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
isLinearView: boolean;
|
||||
};
|
||||
|
||||
const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
const InputFieldRenderer = ({ nodeId, fieldName, isLinearView }: InputFieldProps) => {
|
||||
const fieldInstance = useFieldInputInstance(nodeId, fieldName);
|
||||
const fieldTemplate = useFieldInputTemplate(nodeId, fieldName);
|
||||
|
||||
if (isStringFieldCollectionInputInstance(fieldInstance) && isStringFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<StringFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isStringFieldInputInstance(fieldInstance) && isStringFieldInputTemplate(fieldTemplate)) {
|
||||
return <StringFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<StringFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isBooleanFieldInputInstance(fieldInstance) && isBooleanFieldInputTemplate(fieldTemplate)) {
|
||||
return <BooleanFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<BooleanFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
(isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) ||
|
||||
(isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate))
|
||||
) {
|
||||
return <NumberFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
if (isIntegerFieldInputInstance(fieldInstance) && isIntegerFieldInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<NumberFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFloatFieldInputInstance(fieldInstance) && isFloatFieldInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<NumberFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(fieldInstance) && isIntegerFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<NumberFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(fieldInstance) && isFloatFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return (
|
||||
<NumberFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isEnumFieldInputInstance(fieldInstance) && isEnumFieldInputTemplate(fieldTemplate)) {
|
||||
return <EnumFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<EnumFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isImageFieldCollectionInputInstance(fieldInstance) && isImageFieldCollectionInputTemplate(fieldTemplate)) {
|
||||
return <ImageFieldCollectionInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ImageFieldCollectionInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isImageFieldInputInstance(fieldInstance) && isImageFieldInputTemplate(fieldTemplate)) {
|
||||
return <ImageFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ImageFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isBoardFieldInputInstance(fieldInstance) && isBoardFieldInputTemplate(fieldTemplate)) {
|
||||
return <BoardFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<BoardFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isMainModelFieldInputInstance(fieldInstance) && isMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<MainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isModelIdentifierFieldInputInstance(fieldInstance) && isModelIdentifierFieldInputTemplate(fieldTemplate)) {
|
||||
return <ModelIdentifierFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ModelIdentifierFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSDXLRefinerModelFieldInputInstance(fieldInstance) && isSDXLRefinerModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <RefinerModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<RefinerModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isVAEModelFieldInputInstance(fieldInstance) && isVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <VAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<VAEModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<T5EncoderModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<CLIPEmbedModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isCLIPLEmbedModelFieldInputInstance(fieldInstance) && isCLIPLEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPLEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<CLIPLEmbedModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isCLIPGEmbedModelFieldInputInstance(fieldInstance) && isCLIPGEmbedModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <CLIPGEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<CLIPGEmbedModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isControlLoRAModelFieldInputInstance(fieldInstance) && isControlLoRAModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <ControlLoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ControlLoRAModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<FluxVAEModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<LoRAModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isControlNetModelFieldInputInstance(fieldInstance) && isControlNetModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <ControlNetModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ControlNetModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isIPAdapterModelFieldInputInstance(fieldInstance) && isIPAdapterModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <IPAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<IPAdapterModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isT2IAdapterModelFieldInputInstance(fieldInstance) && isT2IAdapterModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <T2IAdapterModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<T2IAdapterModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (
|
||||
@@ -192,28 +389,64 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isColorFieldInputInstance(fieldInstance) && isColorFieldInputTemplate(fieldTemplate)) {
|
||||
return <ColorFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<ColorFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isFluxMainModelFieldInputInstance(fieldInstance) && isFluxMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <FluxMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<FluxMainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSD3MainModelFieldInputInstance(fieldInstance) && isSD3MainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SD3MainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<SD3MainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSDXLMainModelFieldInputInstance(fieldInstance) && isSDXLMainModelFieldInputTemplate(fieldTemplate)) {
|
||||
return <SDXLMainModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<SDXLMainModelFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (isSchedulerFieldInputInstance(fieldInstance) && isSchedulerFieldInputTemplate(fieldTemplate)) {
|
||||
return <SchedulerFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
|
||||
return (
|
||||
<SchedulerFieldInputComponent
|
||||
nodeId={nodeId}
|
||||
field={fieldInstance}
|
||||
fieldTemplate={fieldTemplate}
|
||||
isLinearView={isLinearView}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (fieldTemplate) {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { Flex, FormControl, FormLabel } from '@invoke-ai/ui-library';
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectInvocationNode, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
@@ -18,7 +18,7 @@ export const InvocationInputFieldCheck = memo(({ nodeId, fieldName, children }:
|
||||
const templates = useStore($templates);
|
||||
const selector = useMemo(
|
||||
() =>
|
||||
createSelector(selectNodesSlice, (nodesSlice) => {
|
||||
createMemoizedSelector(selectNodesSlice, (nodesSlice) => {
|
||||
const node = selectInvocationNode(nodesSlice, nodeId);
|
||||
const instance = node.data.inputs[fieldName];
|
||||
const template = templates[node.data.type];
|
||||
|
||||
@@ -97,7 +97,11 @@ const LinearViewFieldInternal = ({ fieldIdentifier }: Props) => {
|
||||
icon={<PiTrashSimpleBold />}
|
||||
/>
|
||||
</Flex>
|
||||
<InputFieldRenderer nodeId={fieldIdentifier.nodeId} fieldName={fieldIdentifier.fieldName} />
|
||||
<InputFieldRenderer
|
||||
nodeId={fieldIdentifier.nodeId}
|
||||
fieldName={fieldIdentifier.fieldName}
|
||||
isLinearView={true}
|
||||
/>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<DndListDropIndicator dndState={dndListState} />
|
||||
|
||||
@@ -26,7 +26,7 @@ const EnumFieldInputComponent = (props: FieldComponentProps<EnumFieldInputInstan
|
||||
);
|
||||
|
||||
return (
|
||||
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value}>
|
||||
<Select className="nowheel nodrag" onChange={handleValueChanged} value={field.value} size="sm">
|
||||
{fieldTemplate.options.map((option) => (
|
||||
<option key={option} value={option}>
|
||||
{fieldTemplate.ui_choice_labels ? fieldTemplate.ui_choice_labels[option] : option}
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
import { CompositeNumberInput, Flex, FormControl, FormLabel, IconButton } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
type FloatRangeStartStepCountGenerator,
|
||||
getDefaultFloatRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiArrowCounterClockwiseBold } from 'react-icons/pi';
|
||||
|
||||
type FloatRangeGeneratorProps = {
|
||||
state: FloatRangeStartStepCountGenerator;
|
||||
onChange: (state: FloatRangeStartStepCountGenerator) => void;
|
||||
};
|
||||
|
||||
export const FloatRangeGenerator = memo(({ state, onChange }: FloatRangeGeneratorProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onChangeStart = useCallback(
|
||||
(start: number) => {
|
||||
onChange({ ...state, start });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeStep = useCallback(
|
||||
(step: number) => {
|
||||
onChange({ ...state, step });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
const onChangeCount = useCallback(
|
||||
(count: number) => {
|
||||
onChange({ ...state, count });
|
||||
},
|
||||
[onChange, state]
|
||||
);
|
||||
|
||||
const onReset = useCallback(() => {
|
||||
onChange(getDefaultFloatRangeStartStepCountGenerator());
|
||||
}, [onChange]);
|
||||
|
||||
return (
|
||||
<Flex gap={1} alignItems="flex-end" p={1}>
|
||||
<FormControl orientation="vertical" gap={1}>
|
||||
<FormLabel m={0}>{t('common.start')}</FormLabel>
|
||||
<CompositeNumberInput value={state.start} onChange={onChangeStart} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical" gap={1}>
|
||||
<FormLabel m={0}>{t('common.count')}</FormLabel>
|
||||
<CompositeNumberInput value={state.count} onChange={onChangeCount} min={1} max={Infinity} />
|
||||
</FormControl>
|
||||
<FormControl orientation="vertical" gap={1}>
|
||||
<FormLabel m={0}>{t('common.step')}</FormLabel>
|
||||
<CompositeNumberInput value={state.step} onChange={onChangeStep} min={-Infinity} max={Infinity} step={0.01} />
|
||||
</FormControl>
|
||||
<IconButton
|
||||
onClick={onReset}
|
||||
aria-label={t('common.reset')}
|
||||
icon={<PiArrowCounterClockwiseBold />}
|
||||
variant="ghost"
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
|
||||
FloatRangeGenerator.displayName = 'FloatRangeGenerator';
|
||||
@@ -10,9 +10,9 @@ import { addImagesToNodeImageFieldCollectionDndTarget } from 'features/dnd/dnd';
|
||||
import { DndDropTarget } from 'features/dnd/DndDropTarget';
|
||||
import { DndImage } from 'features/dnd/DndImage';
|
||||
import { DndImageIcon } from 'features/dnd/DndImageIcon';
|
||||
import { removeImageFromNodeImageFieldCollectionAction } from 'features/imageActions/actions';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import { fieldImageCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type { ImageField } from 'features/nodes/types/common';
|
||||
import type { ImageFieldCollectionInputInstance, ImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
@@ -61,15 +61,12 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
);
|
||||
|
||||
const onRemoveImage = useCallback(
|
||||
(imageName: string) => {
|
||||
removeImageFromNodeImageFieldCollectionAction({
|
||||
imageName,
|
||||
fieldIdentifier: { nodeId, fieldName: field.name },
|
||||
dispatch: store.dispatch,
|
||||
getState: store.getState,
|
||||
});
|
||||
(index: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue.splice(index, 1);
|
||||
store.dispatch(fieldImageCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, nodeId, store.dispatch, store.getState]
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
return (
|
||||
@@ -90,7 +87,7 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
isError={isInvalid}
|
||||
onUpload={onUpload}
|
||||
fontSize={24}
|
||||
variant="outline"
|
||||
variant="ghost"
|
||||
/>
|
||||
)}
|
||||
{field.value && field.value.length > 0 && (
|
||||
@@ -102,9 +99,9 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid w="full" h="full" templateColumns="repeat(4, 1fr)" gap={1}>
|
||||
{field.value.map(({ image_name }) => (
|
||||
<GridItem key={image_name} position="relative" className="nodrag">
|
||||
<ImageGridItemContent imageName={image_name} onRemoveImage={onRemoveImage} />
|
||||
{field.value.map((value, index) => (
|
||||
<GridItem key={index} position="relative" className="nodrag">
|
||||
<ImageGridItemContent value={value} index={index} onRemoveImage={onRemoveImage} />
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
@@ -124,11 +121,11 @@ export const ImageFieldCollectionInputComponent = memo(
|
||||
ImageFieldCollectionInputComponent.displayName = 'ImageFieldCollectionInputComponent';
|
||||
|
||||
const ImageGridItemContent = memo(
|
||||
({ imageName, onRemoveImage }: { imageName: string; onRemoveImage: (imageName: string) => void }) => {
|
||||
const query = useGetImageDTOQuery(imageName);
|
||||
({ value, index, onRemoveImage }: { value: ImageField; index: number; onRemoveImage: (index: number) => void }) => {
|
||||
const query = useGetImageDTOQuery(value.image_name);
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveImage(imageName);
|
||||
}, [imageName, onRemoveImage]);
|
||||
onRemoveImage(index);
|
||||
}, [index, onRemoveImage]);
|
||||
|
||||
if (query.isLoading) {
|
||||
return <IAINoContentFallbackWithSpinner />;
|
||||
|
||||
@@ -0,0 +1,320 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import {
|
||||
Button,
|
||||
CompositeNumberInput,
|
||||
Divider,
|
||||
Flex,
|
||||
FormControl,
|
||||
FormLabel,
|
||||
Grid,
|
||||
GridItem,
|
||||
IconButton,
|
||||
Switch,
|
||||
Text,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { NUMPY_RAND_MAX } from 'app/constants';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { FloatRangeGenerator } from 'features/nodes/components/flow/nodes/Invocation/fields/inputs/FloatRangeGenerator';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import {
|
||||
fieldNumberCollectionGeneratorCommitted,
|
||||
fieldNumberCollectionGeneratorStateChanged,
|
||||
fieldNumberCollectionGeneratorToggled,
|
||||
fieldNumberCollectionLockLinearViewToggled,
|
||||
fieldNumberCollectionValueChanged,
|
||||
} from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FloatFieldCollectionInputInstance,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
IntegerFieldCollectionInputInstance,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
|
||||
import type {
|
||||
FloatRangeStartStepCountGenerator,
|
||||
IntegerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { isNil, round } from 'lodash-es';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import { memo, useCallback, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiLockSimpleFill, PiLockSimpleOpenFill, PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
'&[data-error=true]': {
|
||||
borderColor: 'error.500',
|
||||
borderStyle: 'solid',
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const NumberFieldCollectionInputComponent = memo(
|
||||
(
|
||||
props:
|
||||
| FieldComponentProps<IntegerFieldCollectionInputInstance, IntegerFieldCollectionInputTemplate>
|
||||
| FieldComponentProps<FloatFieldCollectionInputInstance, FloatFieldCollectionInputTemplate>
|
||||
) => {
|
||||
const { nodeId, field, fieldTemplate, isLinearView } = props;
|
||||
const store = useAppStore();
|
||||
const { t } = useTranslation();
|
||||
|
||||
const isInvalid = useFieldIsInvalid(nodeId, field.name);
|
||||
const isIntegerField = useMemo(() => fieldTemplate.type.name === 'IntegerField', [fieldTemplate.type]);
|
||||
|
||||
const onRemoveNumber = useCallback(
|
||||
(index: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue.splice(index, 1);
|
||||
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
const onChangeNumber = useCallback(
|
||||
(index: number, value: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue[index] = value;
|
||||
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
const onAddNumber = useCallback(() => {
|
||||
const newValue = field.value ? [...field.value, 0] : [0];
|
||||
store.dispatch(fieldNumberCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
}, [field.name, field.value, nodeId, store]);
|
||||
|
||||
const min = useMemo(() => {
|
||||
let min = -NUMPY_RAND_MAX;
|
||||
if (!isNil(fieldTemplate.minimum)) {
|
||||
min = fieldTemplate.minimum;
|
||||
}
|
||||
if (!isNil(fieldTemplate.exclusiveMinimum)) {
|
||||
min = fieldTemplate.exclusiveMinimum + 0.01;
|
||||
}
|
||||
return min;
|
||||
}, [fieldTemplate.exclusiveMinimum, fieldTemplate.minimum]);
|
||||
|
||||
const max = useMemo(() => {
|
||||
let max = NUMPY_RAND_MAX;
|
||||
if (!isNil(fieldTemplate.maximum)) {
|
||||
max = fieldTemplate.maximum;
|
||||
}
|
||||
if (!isNil(fieldTemplate.exclusiveMaximum)) {
|
||||
max = fieldTemplate.exclusiveMaximum - 0.01;
|
||||
}
|
||||
return max;
|
||||
}, [fieldTemplate.exclusiveMaximum, fieldTemplate.maximum]);
|
||||
|
||||
const step = useMemo(() => {
|
||||
if (isNil(fieldTemplate.multipleOf)) {
|
||||
return isIntegerField ? 1 : 0.1;
|
||||
}
|
||||
return fieldTemplate.multipleOf;
|
||||
}, [fieldTemplate.multipleOf, isIntegerField]);
|
||||
|
||||
const fineStep = useMemo(() => {
|
||||
if (isNil(fieldTemplate.multipleOf)) {
|
||||
return isIntegerField ? 1 : 0.01;
|
||||
}
|
||||
return fieldTemplate.multipleOf;
|
||||
}, [fieldTemplate.multipleOf, isIntegerField]);
|
||||
|
||||
const toggleGenerator = useCallback(() => {
|
||||
store.dispatch(fieldNumberCollectionGeneratorToggled({ nodeId, fieldName: field.name }));
|
||||
}, [field.name, nodeId, store]);
|
||||
|
||||
const onChangeGenerator = useCallback(
|
||||
(generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator) => {
|
||||
store.dispatch(fieldNumberCollectionGeneratorStateChanged({ nodeId, fieldName: field.name, generatorState }));
|
||||
},
|
||||
[field.name, nodeId, store]
|
||||
);
|
||||
|
||||
const onCommitGenerator = useCallback(() => {
|
||||
store.dispatch(fieldNumberCollectionGeneratorCommitted({ nodeId, fieldName: field.name }));
|
||||
}, [field.name, nodeId, store]);
|
||||
|
||||
const onToggleLockLinearView = useCallback(() => {
|
||||
store.dispatch(fieldNumberCollectionLockLinearViewToggled({ nodeId, fieldName: field.name }));
|
||||
}, [field.name, nodeId, store]);
|
||||
|
||||
const valuesAsString = useMemo(() => {
|
||||
const resolvedValue = resolveNumberFieldCollectionValue(field);
|
||||
return resolvedValue ? resolvedValue.map((val) => round(val, 2)).join(', ') : '';
|
||||
}, [field]);
|
||||
|
||||
const isLockedOnLinearView = !(field.lockLinearView && isLinearView);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
className="nodrag"
|
||||
position="relative"
|
||||
w="full"
|
||||
h="auto"
|
||||
maxH={64}
|
||||
alignItems="stretch"
|
||||
justifyContent="center"
|
||||
p={1}
|
||||
sx={sx}
|
||||
data-error={isInvalid}
|
||||
borderRadius="base"
|
||||
flexDir="column"
|
||||
gap={1}
|
||||
>
|
||||
<Flex w="full" gap={2}>
|
||||
{!field.generator && (
|
||||
<Button onClick={onAddNumber} variant="ghost" flexGrow={1} size="sm">
|
||||
{t('nodes.addValue')}
|
||||
</Button>
|
||||
)}
|
||||
{field.generator && isLockedOnLinearView && (
|
||||
<Button
|
||||
tooltip={
|
||||
<Flex p={1} flexDir="column">
|
||||
<Text fontWeight="semibold">{t('nodes.generatedValues')}:</Text>
|
||||
<Text fontFamily="monospace">{valuesAsString}</Text>
|
||||
</Flex>
|
||||
}
|
||||
onClick={onCommitGenerator}
|
||||
variant="ghost"
|
||||
flexGrow={1}
|
||||
size="sm"
|
||||
>
|
||||
{t('nodes.commitValues')}
|
||||
</Button>
|
||||
)}
|
||||
{isLockedOnLinearView && (
|
||||
<FormControl w="min-content" pe={isLinearView ? 2 : undefined}>
|
||||
<FormLabel m={0}>{t('nodes.generator')}</FormLabel>
|
||||
<Switch onChange={toggleGenerator} isChecked={Boolean(field.generator)} size="sm" />
|
||||
</FormControl>
|
||||
)}
|
||||
{!isLinearView && (
|
||||
<IconButton
|
||||
onClick={onToggleLockLinearView}
|
||||
tooltip={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
|
||||
aria-label={field.lockLinearView ? t('nodes.unlockLinearView') : t('nodes.lockLinearView')}
|
||||
icon={field.lockLinearView ? <PiLockSimpleFill /> : <PiLockSimpleOpenFill />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
/>
|
||||
)}
|
||||
</Flex>
|
||||
{!field.generator && field.value && field.value.length > 0 && (
|
||||
<>
|
||||
{!(field.lockLinearView && isLinearView) && <Divider />}
|
||||
<OverlayScrollbarsComponent
|
||||
className="nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid gap={1} gridTemplateColumns="auto 1fr auto" alignItems="center">
|
||||
{field.value.map((value, index) => (
|
||||
<NumberListItemContent
|
||||
key={index}
|
||||
value={value}
|
||||
index={index}
|
||||
min={min}
|
||||
max={max}
|
||||
step={step}
|
||||
fineStep={fineStep}
|
||||
isIntegerField={isIntegerField}
|
||||
onRemoveNumber={onRemoveNumber}
|
||||
onChangeNumber={onChangeNumber}
|
||||
/>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
</>
|
||||
)}
|
||||
{field.generator && field.generator.type === 'float-range-generator-start-step-count' && (
|
||||
<>
|
||||
{!(field.lockLinearView && isLinearView) && <Divider />}
|
||||
<FloatRangeGenerator state={field.generator} onChange={onChangeGenerator} />
|
||||
</>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
NumberFieldCollectionInputComponent.displayName = 'NumberFieldCollectionInputComponent';
|
||||
|
||||
type NumberListItemContentProps = {
|
||||
value: number;
|
||||
index: number;
|
||||
isIntegerField: boolean;
|
||||
min: number;
|
||||
max: number;
|
||||
step: number;
|
||||
fineStep: number;
|
||||
onRemoveNumber: (index: number) => void;
|
||||
onChangeNumber: (index: number, value: number) => void;
|
||||
};
|
||||
|
||||
const NumberListItemContent = memo(
|
||||
({
|
||||
value,
|
||||
index,
|
||||
isIntegerField,
|
||||
min,
|
||||
max,
|
||||
step,
|
||||
fineStep,
|
||||
onRemoveNumber,
|
||||
onChangeNumber,
|
||||
}: NumberListItemContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveNumber(index);
|
||||
}, [index, onRemoveNumber]);
|
||||
const onChange = useCallback(
|
||||
(v: number) => {
|
||||
onChangeNumber(index, isIntegerField ? Math.floor(Number(v)) : Number(v));
|
||||
},
|
||||
[index, isIntegerField, onChangeNumber]
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<GridItem>
|
||||
<FormLabel ps={1} m={0}>
|
||||
{index + 1}.
|
||||
</FormLabel>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<CompositeNumberInput
|
||||
onChange={onChange}
|
||||
value={value}
|
||||
min={min}
|
||||
max={max}
|
||||
step={step}
|
||||
fineStep={fineStep}
|
||||
className="nodrag"
|
||||
flexGrow={1}
|
||||
/>
|
||||
</GridItem>
|
||||
<GridItem>
|
||||
<IconButton
|
||||
tabIndex={-1}
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={onClickRemove}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.delete')}
|
||||
/>
|
||||
</GridItem>
|
||||
</>
|
||||
);
|
||||
}
|
||||
);
|
||||
NumberListItemContent.displayName = 'NumberListItemContent';
|
||||
@@ -0,0 +1,152 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Flex, Grid, GridItem, IconButton, Input } from '@invoke-ai/ui-library';
|
||||
import { useAppStore } from 'app/store/nanostores/store';
|
||||
import { getOverlayScrollbarsParams, overlayScrollbarsStyles } from 'common/components/OverlayScrollbars/constants';
|
||||
import { useFieldIsInvalid } from 'features/nodes/hooks/useFieldIsInvalid';
|
||||
import { fieldStringCollectionValueChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
StringFieldCollectionInputInstance,
|
||||
StringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { OverlayScrollbarsComponent } from 'overlayscrollbars-react';
|
||||
import type { ChangeEvent } from 'react';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiPlusBold, PiXBold } from 'react-icons/pi';
|
||||
|
||||
import type { FieldComponentProps } from './types';
|
||||
|
||||
const overlayscrollbarsOptions = getOverlayScrollbarsParams().options;
|
||||
|
||||
const sx = {
|
||||
borderWidth: 1,
|
||||
'&[data-error=true]': {
|
||||
borderColor: 'error.500',
|
||||
borderStyle: 'solid',
|
||||
},
|
||||
} satisfies SystemStyleObject;
|
||||
|
||||
export const StringFieldCollectionInputComponent = memo(
|
||||
(props: FieldComponentProps<StringFieldCollectionInputInstance, StringFieldCollectionInputTemplate>) => {
|
||||
const { nodeId, field } = props;
|
||||
const store = useAppStore();
|
||||
|
||||
const isInvalid = useFieldIsInvalid(nodeId, field.name);
|
||||
|
||||
const onRemoveString = useCallback(
|
||||
(index: number) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue.splice(index, 1);
|
||||
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
const onChangeString = useCallback(
|
||||
(index: number, value: string) => {
|
||||
const newValue = field.value ? [...field.value] : [];
|
||||
newValue[index] = value;
|
||||
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
},
|
||||
[field.name, field.value, nodeId, store]
|
||||
);
|
||||
|
||||
const onAddString = useCallback(() => {
|
||||
const newValue = field.value ? [...field.value, ''] : [''];
|
||||
store.dispatch(fieldStringCollectionValueChanged({ nodeId, fieldName: field.name, value: newValue }));
|
||||
}, [field.name, field.value, nodeId, store]);
|
||||
|
||||
return (
|
||||
<Flex
|
||||
className="nodrag"
|
||||
position="relative"
|
||||
w="full"
|
||||
h="full"
|
||||
maxH={64}
|
||||
alignItems="stretch"
|
||||
justifyContent="center"
|
||||
>
|
||||
{(!field.value || field.value.length === 0) && (
|
||||
<Box w="full" sx={sx} data-error={isInvalid} borderRadius="base">
|
||||
<IconButton
|
||||
w="full"
|
||||
onClick={onAddString}
|
||||
aria-label="Add Item"
|
||||
icon={<PiPlusBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
/>
|
||||
</Box>
|
||||
)}
|
||||
{field.value && field.value.length > 0 && (
|
||||
<Box w="full" h="auto" p={1} sx={sx} data-error={isInvalid} borderRadius="base">
|
||||
<OverlayScrollbarsComponent
|
||||
className="nowheel"
|
||||
defer
|
||||
style={overlayScrollbarsStyles}
|
||||
options={overlayscrollbarsOptions}
|
||||
>
|
||||
<Grid w="full" h="full" templateColumns="repeat(1, 1fr)" gap={1}>
|
||||
<IconButton
|
||||
onClick={onAddString}
|
||||
aria-label="Add Item"
|
||||
icon={<PiPlusBold />}
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
/>
|
||||
{field.value.map((value, index) => (
|
||||
<GridItem key={index} position="relative" className="nodrag">
|
||||
<StringListItemContent
|
||||
value={value}
|
||||
index={index}
|
||||
onRemoveString={onRemoveString}
|
||||
onChangeString={onChangeString}
|
||||
/>
|
||||
</GridItem>
|
||||
))}
|
||||
</Grid>
|
||||
</OverlayScrollbarsComponent>
|
||||
</Box>
|
||||
)}
|
||||
</Flex>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
StringFieldCollectionInputComponent.displayName = 'StringFieldCollectionInputComponent';
|
||||
|
||||
type StringListItemContentProps = {
|
||||
value: string;
|
||||
index: number;
|
||||
onRemoveString: (index: number) => void;
|
||||
onChangeString: (index: number, value: string) => void;
|
||||
};
|
||||
|
||||
const StringListItemContent = memo(({ value, index, onRemoveString, onChangeString }: StringListItemContentProps) => {
|
||||
const { t } = useTranslation();
|
||||
|
||||
const onClickRemove = useCallback(() => {
|
||||
onRemoveString(index);
|
||||
}, [index, onRemoveString]);
|
||||
const onChange = useCallback(
|
||||
(e: ChangeEvent<HTMLInputElement>) => {
|
||||
onChangeString(index, e.target.value);
|
||||
},
|
||||
[index, onChangeString]
|
||||
);
|
||||
return (
|
||||
<Flex alignItems="center" gap={1}>
|
||||
<Input size="xs" resize="none" value={value} onChange={onChange} />
|
||||
<IconButton
|
||||
size="sm"
|
||||
variant="link"
|
||||
alignSelf="stretch"
|
||||
onClick={onClickRemove}
|
||||
icon={<PiXBold />}
|
||||
aria-label={t('common.remove')}
|
||||
tooltip={t('common.remove')}
|
||||
/>
|
||||
</Flex>
|
||||
);
|
||||
});
|
||||
StringListItemContent.displayName = 'StringListItemContent';
|
||||
@@ -4,4 +4,5 @@ export type FieldComponentProps<V extends FieldInputInstance, T extends FieldInp
|
||||
nodeId: string;
|
||||
field: V;
|
||||
fieldTemplate: T;
|
||||
isLinearView: boolean;
|
||||
};
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import type { SystemStyleObject } from '@invoke-ai/ui-library';
|
||||
import { Box, Editable, EditableInput, EditablePreview, Flex, useEditableControls } from '@invoke-ai/ui-library';
|
||||
import type { SystemStyleObject, TextProps } from '@invoke-ai/ui-library';
|
||||
import { Box, Editable, EditableInput, Flex, Text, useEditableControls } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { useBatchGroupColorToken } from 'features/nodes/hooks/useBatchGroupColorToken';
|
||||
import { useBatchGroupId } from 'features/nodes/hooks/useBatchGroupId';
|
||||
import { useNodeLabel } from 'features/nodes/hooks/useNodeLabel';
|
||||
import { useNodeTemplateTitle } from 'features/nodes/hooks/useNodeTemplateTitle';
|
||||
import { nodeLabelChanged } from 'features/nodes/store/nodesSlice';
|
||||
import { DRAG_HANDLE_CLASSNAME } from 'features/nodes/types/constants';
|
||||
import type { MouseEvent } from 'react';
|
||||
import { memo, useCallback, useEffect, useState } from 'react';
|
||||
import { memo, useCallback, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
type Props = {
|
||||
@@ -17,6 +19,8 @@ type Props = {
|
||||
const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const label = useNodeLabel(nodeId);
|
||||
const batchGroupId = useBatchGroupId(nodeId);
|
||||
const batchGroupColorToken = useBatchGroupColorToken(batchGroupId);
|
||||
const templateTitle = useNodeTemplateTitle(nodeId);
|
||||
const { t } = useTranslation();
|
||||
|
||||
@@ -29,6 +33,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
[dispatch, nodeId, title, templateTitle, label, t]
|
||||
);
|
||||
|
||||
const localTitleWithBatchGroupId = useMemo(() => {
|
||||
if (!batchGroupId) {
|
||||
return localTitle;
|
||||
}
|
||||
if (batchGroupId === 'None') {
|
||||
return `${localTitle} (${t('nodes.noBatchGroup')})`;
|
||||
}
|
||||
return `${localTitle} (${batchGroupId})`;
|
||||
}, [batchGroupId, localTitle, t]);
|
||||
|
||||
const handleChange = useCallback((newTitle: string) => {
|
||||
setLocalTitle(newTitle);
|
||||
}, []);
|
||||
@@ -50,7 +64,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
w="full"
|
||||
h="full"
|
||||
>
|
||||
<EditablePreview fontSize="sm" p={0} w="full" noOfLines={1} />
|
||||
<Preview
|
||||
fontSize="sm"
|
||||
p={0}
|
||||
w="full"
|
||||
noOfLines={1}
|
||||
color={batchGroupColorToken}
|
||||
fontWeight={batchGroupId ? 'semibold' : undefined}
|
||||
>
|
||||
{localTitleWithBatchGroupId}
|
||||
</Preview>
|
||||
<EditableInput className="nodrag" fontSize="sm" sx={editableInputStyles} />
|
||||
<EditableControls />
|
||||
</Editable>
|
||||
@@ -60,6 +83,16 @@ const NodeTitle = ({ nodeId, title }: Props) => {
|
||||
|
||||
export default memo(NodeTitle);
|
||||
|
||||
const Preview = (props: TextProps) => {
|
||||
const { isEditing } = useEditableControls();
|
||||
|
||||
if (isEditing) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <Text {...props} />;
|
||||
};
|
||||
|
||||
function EditableControls() {
|
||||
const { isEditing, getEditButtonProps } = useEditableControls();
|
||||
const handleDoubleClick = useCallback(
|
||||
|
||||
@@ -5,7 +5,7 @@ import NodeOpacitySlider from './NodeOpacitySlider';
|
||||
import ViewportControls from './ViewportControls';
|
||||
|
||||
const BottomLeftPanel = () => (
|
||||
<Flex gap={2} position="absolute" bottom={0} insetInlineStart={0}>
|
||||
<Flex gap={2} position="absolute" bottom={2} insetInlineStart={2}>
|
||||
<ViewportControls />
|
||||
<NodeOpacitySlider />
|
||||
</Flex>
|
||||
|
||||
@@ -20,7 +20,7 @@ const MinimapPanel = () => {
|
||||
const shouldShowMinimapPanel = useAppSelector(selectShouldShowMinimapPanel);
|
||||
|
||||
return (
|
||||
<Flex gap={2} position="absolute" bottom={0} insetInlineEnd={0}>
|
||||
<Flex gap={2} position="absolute" bottom={2} insetInlineEnd={2}>
|
||||
{shouldShowMinimapPanel && (
|
||||
<ChakraMiniMap
|
||||
pannable
|
||||
|
||||
@@ -12,7 +12,7 @@ import { memo } from 'react';
|
||||
const TopCenterPanel = () => {
|
||||
const name = useAppSelector(selectWorkflowName);
|
||||
return (
|
||||
<Flex gap={2} top={0} left={0} right={0} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
<Flex gap={2} top={2} left={2} right={2} position="absolute" alignItems="flex-start" pointerEvents="none">
|
||||
<Flex gap="2">
|
||||
<AddNodeButton />
|
||||
<UpdateNodesButton />
|
||||
|
||||
@@ -46,7 +46,7 @@ const WorkflowFieldInternal = ({ nodeId, fieldName }: Props) => {
|
||||
</Flex>
|
||||
</Tooltip>
|
||||
</Flex>
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} />
|
||||
<InputFieldRenderer nodeId={nodeId} fieldName={fieldName} isLinearView={true} />
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useBatchGroupColorToken = (batchGroupId?: string) => {
|
||||
const batchGroupColorToken = useMemo(() => {
|
||||
switch (batchGroupId) {
|
||||
case 'Group 1':
|
||||
return 'invokeGreen.300';
|
||||
case 'Group 2':
|
||||
return 'invokeBlue.300';
|
||||
case 'Group 3':
|
||||
return 'invokePurple.200';
|
||||
case 'Group 4':
|
||||
return 'invokeRed.300';
|
||||
case 'Group 5':
|
||||
return 'invokeYellow.300';
|
||||
default:
|
||||
return undefined;
|
||||
}
|
||||
}, [batchGroupId]);
|
||||
|
||||
return batchGroupColorToken;
|
||||
};
|
||||
@@ -0,0 +1,19 @@
|
||||
import { useNode } from 'features/nodes/hooks/useNode';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useBatchGroupId = (nodeId: string) => {
|
||||
const node = useNode(nodeId);
|
||||
|
||||
const batchGroupId = useMemo(() => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
if (!isBatchNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node.data.inputs['batch_group_id']?.value as string;
|
||||
}, [node]);
|
||||
|
||||
return batchGroupId;
|
||||
};
|
||||
@@ -3,7 +3,21 @@ import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { useConnectionState } from 'features/nodes/hooks/useConnectionState';
|
||||
import { useFieldInputTemplate } from 'features/nodes/hooks/useFieldInputTemplate';
|
||||
import { selectFieldInputInstance, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
validateImageFieldCollectionValue,
|
||||
validateNumberFieldCollectionValue,
|
||||
validateStringFieldCollectionValue,
|
||||
} from 'features/nodes/types/fieldValidators';
|
||||
import { useMemo } from 'react';
|
||||
|
||||
export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
|
||||
@@ -35,13 +49,27 @@ export const useFieldIsInvalid = (nodeId: string, fieldName: string) => {
|
||||
}
|
||||
|
||||
// Else special handling for individual field types
|
||||
|
||||
if (isImageFieldCollectionInputInstance(field) && isImageFieldCollectionInputTemplate(template)) {
|
||||
// Image collections may have min or max item counts
|
||||
if (template.minItems !== undefined && field.value.length < template.minItems) {
|
||||
if (validateImageFieldCollectionValue(field.value, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (template.maxItems !== undefined && field.value.length > template.maxItems) {
|
||||
if (isStringFieldCollectionInputInstance(field) && isStringFieldCollectionInputTemplate(template)) {
|
||||
if (validateStringFieldCollectionValue(field.value, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputTemplate(template)) {
|
||||
if (validateNumberFieldCollectionValue(field, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputTemplate(template)) {
|
||||
if (validateNumberFieldCollectionValue(field, template).length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useFieldValue } from 'features/nodes/hooks/useFieldValue';
|
||||
import { useFieldInputInstance } from 'features/nodes/hooks/useFieldInputInstance';
|
||||
import { fieldValueReset } from 'features/nodes/store/nodesSlice';
|
||||
import { selectWorkflowSlice } from 'features/nodes/store/workflowSlice';
|
||||
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { isEqual } from 'lodash-es';
|
||||
import { useCallback, useMemo } from 'react';
|
||||
|
||||
@@ -10,19 +11,38 @@ export const useFieldOriginalValue = (nodeId: string, fieldName: string) => {
|
||||
const dispatch = useAppDispatch();
|
||||
const selectOriginalExposedFieldValues = useMemo(
|
||||
() =>
|
||||
createSelector(
|
||||
selectWorkflowSlice,
|
||||
(workflow) =>
|
||||
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)?.value
|
||||
createMemoizedSelector(selectWorkflowSlice, (workflow) =>
|
||||
workflow.originalExposedFieldValues.find((v) => v.nodeId === nodeId && v.fieldName === fieldName)
|
||||
),
|
||||
[nodeId, fieldName]
|
||||
);
|
||||
const originalValue = useAppSelector(selectOriginalExposedFieldValues);
|
||||
const value = useFieldValue(nodeId, fieldName);
|
||||
const isValueChanged = useMemo(() => !isEqual(value, originalValue), [value, originalValue]);
|
||||
const exposedField = useAppSelector(selectOriginalExposedFieldValues);
|
||||
const field = useFieldInputInstance(nodeId, fieldName);
|
||||
const isValueChanged = useMemo(() => {
|
||||
if (!field) {
|
||||
// Field is not found, so it is not changed
|
||||
return false;
|
||||
}
|
||||
if (isFloatFieldCollectionInputInstance(field) && isFloatFieldCollectionInputInstance(exposedField?.field)) {
|
||||
return !isEqual(field.generator, exposedField.field.generator);
|
||||
}
|
||||
if (isIntegerFieldCollectionInputInstance(field) && isIntegerFieldCollectionInputInstance(exposedField?.field)) {
|
||||
return !isEqual(field.generator, exposedField.field.generator);
|
||||
}
|
||||
return !isEqual(field.value, exposedField?.field.value);
|
||||
}, [field, exposedField]);
|
||||
const onReset = useCallback(() => {
|
||||
dispatch(fieldValueReset({ nodeId, fieldName, value: originalValue }));
|
||||
}, [dispatch, fieldName, nodeId, originalValue]);
|
||||
if (!exposedField) {
|
||||
return;
|
||||
}
|
||||
const { value } = exposedField.field;
|
||||
const generator =
|
||||
isIntegerFieldCollectionInputInstance(exposedField.field) ||
|
||||
isFloatFieldCollectionInputInstance(exposedField.field)
|
||||
? exposedField.field.generator
|
||||
: undefined;
|
||||
dispatch(fieldValueReset({ nodeId, fieldName, value, generator }));
|
||||
}, [dispatch, fieldName, nodeId, exposedField]);
|
||||
|
||||
return { originalValue, isValueChanged, onReset };
|
||||
return { originalValue: exposedField, isValueChanged, onReset };
|
||||
};
|
||||
|
||||
@@ -19,6 +19,7 @@ import type {
|
||||
FluxVAEModelFieldValue,
|
||||
ImageFieldCollectionValue,
|
||||
ImageFieldValue,
|
||||
IntegerFieldCollectionValue,
|
||||
IntegerFieldValue,
|
||||
IPAdapterModelFieldValue,
|
||||
LoRAModelFieldValue,
|
||||
@@ -28,12 +29,15 @@ import type {
|
||||
SDXLRefinerModelFieldValue,
|
||||
SpandrelImageToImageModelFieldValue,
|
||||
StatefulFieldValue,
|
||||
StringFieldCollectionValue,
|
||||
StringFieldValue,
|
||||
T2IAdapterModelFieldValue,
|
||||
T5EncoderModelFieldValue,
|
||||
VAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
zBoardFieldValue,
|
||||
zBooleanFieldValue,
|
||||
zCLIPEmbedModelFieldValue,
|
||||
@@ -43,10 +47,12 @@ import {
|
||||
zControlLoRAModelFieldValue,
|
||||
zControlNetModelFieldValue,
|
||||
zEnumFieldValue,
|
||||
zFloatFieldCollectionValue,
|
||||
zFloatFieldValue,
|
||||
zFluxVAEModelFieldValue,
|
||||
zImageFieldCollectionValue,
|
||||
zImageFieldValue,
|
||||
zIntegerFieldCollectionValue,
|
||||
zIntegerFieldValue,
|
||||
zIPAdapterModelFieldValue,
|
||||
zLoRAModelFieldValue,
|
||||
@@ -56,11 +62,22 @@ import {
|
||||
zSDXLRefinerModelFieldValue,
|
||||
zSpandrelImageToImageModelFieldValue,
|
||||
zStatefulFieldValue,
|
||||
zStringFieldCollectionValue,
|
||||
zStringFieldValue,
|
||||
zT2IAdapterModelFieldValue,
|
||||
zT5EncoderModelFieldValue,
|
||||
zVAEModelFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
FloatRangeStartStepCountGenerator,
|
||||
IntegerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import {
|
||||
floatRangeStartStepCountGenerator,
|
||||
getDefaultFloatRangeStartStepCountGenerator,
|
||||
getDefaultIntegerRangeStartStepCountGenerator,
|
||||
integerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import type { AnyNode, InvocationNodeEdge } from 'features/nodes/types/invocation';
|
||||
import { isInvocationNode, isNotesNode } from 'features/nodes/types/invocation';
|
||||
import { atom, computed } from 'nanostores';
|
||||
@@ -78,11 +95,22 @@ const initialNodesState: NodesState = {
|
||||
edges: [],
|
||||
};
|
||||
|
||||
type FieldValueAction<T extends FieldValue> = PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value: T;
|
||||
}>;
|
||||
type FieldValueAction<T extends FieldValue, U = unknown> = PayloadAction<
|
||||
{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
value: T;
|
||||
} & U
|
||||
>;
|
||||
|
||||
const selectField = (state: NodesState, nodeId: string, fieldName: string) => {
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
return node.data?.inputs[fieldName];
|
||||
};
|
||||
|
||||
const fieldValueReducer = <T extends FieldValue>(
|
||||
state: NodesState,
|
||||
@@ -90,17 +118,24 @@ const fieldValueReducer = <T extends FieldValue>(
|
||||
schema: z.ZodTypeAny
|
||||
) => {
|
||||
const { nodeId, fieldName, value } = action.payload;
|
||||
const nodeIndex = state.nodes.findIndex((n) => n.id === nodeId);
|
||||
const node = state.nodes?.[nodeIndex];
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
const input = node.data?.inputs[fieldName];
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
const result = schema.safeParse(value);
|
||||
if (!input || nodeIndex < 0 || !result.success) {
|
||||
if (!field || !result.success) {
|
||||
return;
|
||||
}
|
||||
input.value = result.data;
|
||||
field.value = result.data;
|
||||
// Special handling if the field value is being reset
|
||||
if (result.data === undefined) {
|
||||
if (isFloatFieldCollectionInputInstance(field)) {
|
||||
if (field.lockLinearView && field.generator) {
|
||||
field.generator = getDefaultFloatRangeStartStepCountGenerator();
|
||||
}
|
||||
} else if (isIntegerFieldCollectionInputInstance(field)) {
|
||||
if (field.lockLinearView && field.generator) {
|
||||
field.generator = getDefaultIntegerRangeStartStepCountGenerator();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
export const nodesSlice = createSlice({
|
||||
@@ -305,15 +340,123 @@ export const nodesSlice = createSlice({
|
||||
}
|
||||
node.data.notes = notes;
|
||||
},
|
||||
fieldValueReset: (state, action: FieldValueAction<StatefulFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStatefulFieldValue);
|
||||
fieldValueReset: (
|
||||
state,
|
||||
action: FieldValueAction<
|
||||
StatefulFieldValue,
|
||||
{ generator?: IntegerRangeStartStepCountGenerator | FloatRangeStartStepCountGenerator }
|
||||
>
|
||||
) => {
|
||||
const { nodeId, fieldName, value, generator } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
const result = zStatefulFieldValue.safeParse(value);
|
||||
|
||||
if (!field || !result.success) {
|
||||
return;
|
||||
}
|
||||
|
||||
field.value = result.data;
|
||||
|
||||
if (isFloatFieldCollectionInputInstance(field) && generator?.type === 'float-range-generator-start-step-count') {
|
||||
field.generator = generator;
|
||||
} else if (
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
generator?.type === 'integer-range-generator-start-step-count'
|
||||
) {
|
||||
field.generator = generator;
|
||||
}
|
||||
},
|
||||
fieldStringValueChanged: (state, action: FieldValueAction<StringFieldValue>) => {
|
||||
fieldValueReducer(state, action, zStringFieldValue);
|
||||
},
|
||||
fieldStringCollectionValueChanged: (state, action: FieldValueAction<StringFieldCollectionValue>) => {
|
||||
fieldValueReducer(state, action, zStringFieldCollectionValue);
|
||||
},
|
||||
fieldNumberValueChanged: (state, action: FieldValueAction<IntegerFieldValue | FloatFieldValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerFieldValue.or(zFloatFieldValue));
|
||||
},
|
||||
fieldNumberCollectionValueChanged: (state, action: FieldValueAction<IntegerFieldCollectionValue>) => {
|
||||
fieldValueReducer(state, action, zIntegerFieldCollectionValue.or(zFloatFieldCollectionValue));
|
||||
},
|
||||
fieldNumberCollectionGeneratorToggled: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
|
||||
const { nodeId, fieldName } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
if (isFloatFieldCollectionInputInstance(field)) {
|
||||
field.generator = field.generator ? undefined : getDefaultFloatRangeStartStepCountGenerator();
|
||||
} else if (isIntegerFieldCollectionInputInstance(field)) {
|
||||
field.generator = field.generator ? undefined : getDefaultIntegerRangeStartStepCountGenerator();
|
||||
} else {
|
||||
// This should never happen
|
||||
}
|
||||
},
|
||||
fieldNumberCollectionGeneratorStateChanged: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
nodeId: string;
|
||||
fieldName: string;
|
||||
generatorState: FloatRangeStartStepCountGenerator | IntegerRangeStartStepCountGenerator;
|
||||
}>
|
||||
) => {
|
||||
const { nodeId, fieldName, generatorState } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
isFloatFieldCollectionInputInstance(field) &&
|
||||
generatorState.type === 'float-range-generator-start-step-count'
|
||||
) {
|
||||
field.generator = generatorState;
|
||||
} else if (
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
generatorState.type === 'integer-range-generator-start-step-count'
|
||||
) {
|
||||
field.generator = generatorState;
|
||||
} else {
|
||||
// This should never happen
|
||||
}
|
||||
},
|
||||
fieldNumberCollectionGeneratorCommitted: (state, action: PayloadAction<{ nodeId: string; fieldName: string }>) => {
|
||||
const { nodeId, fieldName } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
if (
|
||||
isFloatFieldCollectionInputInstance(field) &&
|
||||
field.generator &&
|
||||
field.generator.type === 'float-range-generator-start-step-count'
|
||||
) {
|
||||
field.value = floatRangeStartStepCountGenerator(field.generator);
|
||||
field.generator = undefined;
|
||||
} else if (
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
field.generator &&
|
||||
field.generator.type === 'integer-range-generator-start-step-count'
|
||||
) {
|
||||
field.value = integerRangeStartStepCountGenerator(field.generator);
|
||||
field.generator = undefined;
|
||||
} else {
|
||||
// This should never happen
|
||||
}
|
||||
},
|
||||
fieldNumberCollectionLockLinearViewToggled: (
|
||||
state,
|
||||
action: PayloadAction<{ nodeId: string; fieldName: string }>
|
||||
) => {
|
||||
const { nodeId, fieldName } = action.payload;
|
||||
const field = selectField(state, nodeId, fieldName);
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
if (!isFloatFieldCollectionInputInstance(field) && !isIntegerFieldCollectionInputInstance(field)) {
|
||||
return;
|
||||
}
|
||||
field.lockLinearView = !field.lockLinearView;
|
||||
},
|
||||
fieldBooleanValueChanged: (state, action: FieldValueAction<BooleanFieldValue>) => {
|
||||
fieldValueReducer(state, action, zBooleanFieldValue);
|
||||
},
|
||||
@@ -435,9 +578,15 @@ export const {
|
||||
fieldModelIdentifierValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
fieldNumberCollectionValueChanged,
|
||||
fieldNumberCollectionGeneratorToggled,
|
||||
fieldNumberCollectionGeneratorStateChanged,
|
||||
fieldNumberCollectionGeneratorCommitted,
|
||||
fieldNumberCollectionLockLinearViewToggled,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
@@ -546,9 +695,11 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
|
||||
fieldLoRAModelValueChanged,
|
||||
fieldMainModelValueChanged,
|
||||
fieldNumberValueChanged,
|
||||
fieldNumberCollectionValueChanged,
|
||||
fieldRefinerModelValueChanged,
|
||||
fieldSchedulerValueChanged,
|
||||
fieldStringValueChanged,
|
||||
fieldStringCollectionValueChanged,
|
||||
fieldVaeModelValueChanged,
|
||||
fieldT5EncoderValueChanged,
|
||||
fieldCLIPEmbedValueChanged,
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import type {
|
||||
FieldIdentifier,
|
||||
FieldInputInstance,
|
||||
FieldInputTemplate,
|
||||
FieldOutputTemplate,
|
||||
StatefulFieldValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import type {
|
||||
AnyNode,
|
||||
@@ -31,15 +31,15 @@ export type NodesState = {
|
||||
};
|
||||
|
||||
export type WorkflowMode = 'edit' | 'view';
|
||||
export type FieldIdentifierWithValue = FieldIdentifier & {
|
||||
value: StatefulFieldValue;
|
||||
export type FieldIdentifierWithInstance = FieldIdentifier & {
|
||||
field: FieldInputInstance;
|
||||
};
|
||||
|
||||
export type WorkflowsState = Omit<WorkflowV3, 'nodes' | 'edges'> & {
|
||||
_version: 1;
|
||||
_version: 2;
|
||||
isTouched: boolean;
|
||||
mode: WorkflowMode;
|
||||
originalExposedFieldValues: FieldIdentifierWithValue[];
|
||||
originalExposedFieldValues: FieldIdentifierWithInstance[];
|
||||
searchTerm: string;
|
||||
orderBy?: WorkflowRecordOrderBy;
|
||||
orderDirection: SQLiteDirection;
|
||||
|
||||
@@ -5,7 +5,7 @@ import { deepClone } from 'common/util/deepClone';
|
||||
import { workflowLoaded } from 'features/nodes/store/actions';
|
||||
import { isAnyNodeOrEdgeMutation, nodeEditorReset, nodesChanged } from 'features/nodes/store/nodesSlice';
|
||||
import type {
|
||||
FieldIdentifierWithValue,
|
||||
FieldIdentifierWithInstance,
|
||||
WorkflowMode,
|
||||
WorkflowsState as WorkflowState,
|
||||
} from 'features/nodes/store/types';
|
||||
@@ -31,7 +31,7 @@ const blankWorkflow: Omit<WorkflowV3, 'nodes' | 'edges'> = {
|
||||
};
|
||||
|
||||
const initialWorkflowState: WorkflowState = {
|
||||
_version: 1,
|
||||
_version: 2,
|
||||
isTouched: false,
|
||||
mode: 'view',
|
||||
originalExposedFieldValues: [],
|
||||
@@ -62,7 +62,7 @@ export const workflowSlice = createSlice({
|
||||
const { id, isOpen } = action.payload;
|
||||
state.categorySections[id] = isOpen;
|
||||
},
|
||||
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithValue>) => {
|
||||
workflowExposedFieldAdded: (state, action: PayloadAction<FieldIdentifierWithInstance>) => {
|
||||
state.exposedFields = uniqBy(
|
||||
state.exposedFields.concat(omit(action.payload, 'value')),
|
||||
(field) => `${field.nodeId}-${field.fieldName}`
|
||||
@@ -128,25 +128,25 @@ export const workflowSlice = createSlice({
|
||||
builder.addCase(workflowLoaded, (state, action) => {
|
||||
const { nodes, edges: _edges, ...workflowExtra } = action.payload;
|
||||
|
||||
const originalExposedFieldValues: FieldIdentifierWithValue[] = [];
|
||||
const originalExposedFieldValues: FieldIdentifierWithInstance[] = [];
|
||||
|
||||
workflowExtra.exposedFields.forEach((field) => {
|
||||
const node = nodes.find((n) => n.id === field.nodeId);
|
||||
workflowExtra.exposedFields.forEach(({ nodeId, fieldName }) => {
|
||||
const node = nodes.find((n) => n.id === nodeId);
|
||||
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const input = node.data.inputs[field.fieldName];
|
||||
const field = node.data.inputs[fieldName];
|
||||
|
||||
if (!input) {
|
||||
if (!field) {
|
||||
return;
|
||||
}
|
||||
|
||||
const originalExposedFieldValue = {
|
||||
nodeId: field.nodeId,
|
||||
fieldName: field.fieldName,
|
||||
value: input.value,
|
||||
nodeId,
|
||||
fieldName,
|
||||
field,
|
||||
};
|
||||
originalExposedFieldValues.push(originalExposedFieldValue);
|
||||
});
|
||||
@@ -243,6 +243,9 @@ const migrateWorkflowState = (state: any): any => {
|
||||
if (!('_version' in state)) {
|
||||
state._version = 1;
|
||||
}
|
||||
if (state._version === 1) {
|
||||
return deepClone(initialWorkflowState);
|
||||
}
|
||||
return state;
|
||||
};
|
||||
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
import {
|
||||
zFloatRangeStartStepCountGenerator,
|
||||
zIntegerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { buildTypeGuard } from 'features/parameters/types/parameterSchemas';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { zBoardField, zColorField, zImageField, zModelIdentifierField, zSchedulerField } from './common';
|
||||
@@ -78,14 +83,35 @@ const zIntegerFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('IntegerField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zIntegerCollectionFieldType = z.object({
|
||||
name: z.literal('IntegerField'),
|
||||
cardinality: z.literal(COLLECTION),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
export const isIntegerCollectionFieldType = buildTypeGuard(zIntegerCollectionFieldType);
|
||||
|
||||
const zFloatFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('FloatField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zFloatCollectionFieldType = z.object({
|
||||
name: z.literal('FloatField'),
|
||||
cardinality: z.literal(COLLECTION),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
export const isFloatCollectionFieldType = buildTypeGuard(zFloatCollectionFieldType);
|
||||
|
||||
const zStringFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('StringField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
const zStringCollectionFieldType = z.object({
|
||||
name: z.literal('StringField'),
|
||||
cardinality: z.literal(COLLECTION),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
export const isStringCollectionFieldType = buildTypeGuard(zStringCollectionFieldType);
|
||||
|
||||
const zBooleanFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BooleanField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -103,9 +129,7 @@ const zImageCollectionFieldType = z.object({
|
||||
cardinality: z.literal(COLLECTION),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
});
|
||||
export const isImageCollectionFieldType = (
|
||||
fieldType: FieldType
|
||||
): fieldType is z.infer<typeof zImageCollectionFieldType> => zImageCollectionFieldType.safeParse(fieldType).success;
|
||||
export const isImageCollectionFieldType = buildTypeGuard(zImageCollectionFieldType);
|
||||
const zBoardFieldType = zFieldTypeBase.extend({
|
||||
name: z.literal('BoardField'),
|
||||
originalType: zStatelessFieldType.optional(),
|
||||
@@ -254,10 +278,48 @@ const zIntegerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type IntegerFieldValue = z.infer<typeof zIntegerFieldValue>;
|
||||
export type IntegerFieldInputInstance = z.infer<typeof zIntegerFieldInputInstance>;
|
||||
export type IntegerFieldInputTemplate = z.infer<typeof zIntegerFieldInputTemplate>;
|
||||
export const isIntegerFieldInputInstance = (val: unknown): val is IntegerFieldInputInstance =>
|
||||
zIntegerFieldInputInstance.safeParse(val).success;
|
||||
export const isIntegerFieldInputTemplate = (val: unknown): val is IntegerFieldInputTemplate =>
|
||||
zIntegerFieldInputTemplate.safeParse(val).success;
|
||||
export const isIntegerFieldInputInstance = buildTypeGuard(zIntegerFieldInputInstance);
|
||||
export const isIntegerFieldInputTemplate = buildTypeGuard(zIntegerFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region IntegerField Collection
|
||||
export const zIntegerFieldCollectionValue = z.array(zIntegerFieldValue).optional();
|
||||
const zIntegerFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zIntegerFieldCollectionValue,
|
||||
generator: zIntegerRangeStartStepCountGenerator.optional(),
|
||||
lockLinearView: z.boolean().default(false),
|
||||
});
|
||||
const zIntegerFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
.extend({
|
||||
type: zIntegerCollectionFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zIntegerFieldCollectionValue,
|
||||
maxItems: z.number().int().gte(0).optional(),
|
||||
minItems: z.number().int().gte(0).optional(),
|
||||
multipleOf: z.number().int().optional(),
|
||||
maximum: z.number().int().optional(),
|
||||
exclusiveMaximum: z.number().int().optional(),
|
||||
minimum: z.number().int().optional(),
|
||||
exclusiveMinimum: z.number().int().optional(),
|
||||
})
|
||||
.refine(
|
||||
(val) => {
|
||||
if (val.maxItems !== undefined && val.minItems !== undefined) {
|
||||
return val.maxItems >= val.minItems;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxItems must be greater than or equal to minItems' }
|
||||
);
|
||||
|
||||
const zIntegerFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zIntegerCollectionFieldType,
|
||||
});
|
||||
export type IntegerFieldCollectionValue = z.infer<typeof zIntegerFieldCollectionValue>;
|
||||
export type IntegerFieldCollectionInputInstance = z.infer<typeof zIntegerFieldCollectionInputInstance>;
|
||||
export type IntegerFieldCollectionInputTemplate = z.infer<typeof zIntegerFieldCollectionInputTemplate>;
|
||||
export const isIntegerFieldCollectionInputInstance = buildTypeGuard(zIntegerFieldCollectionInputInstance);
|
||||
export const isIntegerFieldCollectionInputTemplate = buildTypeGuard(zIntegerFieldCollectionInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region FloatField
|
||||
@@ -282,10 +344,48 @@ const zFloatFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type FloatFieldValue = z.infer<typeof zFloatFieldValue>;
|
||||
export type FloatFieldInputInstance = z.infer<typeof zFloatFieldInputInstance>;
|
||||
export type FloatFieldInputTemplate = z.infer<typeof zFloatFieldInputTemplate>;
|
||||
export const isFloatFieldInputInstance = (val: unknown): val is FloatFieldInputInstance =>
|
||||
zFloatFieldInputInstance.safeParse(val).success;
|
||||
export const isFloatFieldInputTemplate = (val: unknown): val is FloatFieldInputTemplate =>
|
||||
zFloatFieldInputTemplate.safeParse(val).success;
|
||||
export const isFloatFieldInputInstance = buildTypeGuard(zFloatFieldInputInstance);
|
||||
export const isFloatFieldInputTemplate = buildTypeGuard(zFloatFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region FloatField Collection
|
||||
|
||||
export const zFloatFieldCollectionValue = z.array(zFloatFieldValue).optional();
|
||||
const zFloatFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zFloatFieldCollectionValue,
|
||||
generator: zFloatRangeStartStepCountGenerator.optional(),
|
||||
lockLinearView: z.boolean().default(false),
|
||||
});
|
||||
const zFloatFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
.extend({
|
||||
type: zFloatCollectionFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zFloatFieldCollectionValue,
|
||||
maxItems: z.number().int().gte(0).optional(),
|
||||
minItems: z.number().int().gte(0).optional(),
|
||||
multipleOf: z.number().int().optional(),
|
||||
maximum: z.number().optional(),
|
||||
exclusiveMaximum: z.number().optional(),
|
||||
minimum: z.number().optional(),
|
||||
exclusiveMinimum: z.number().optional(),
|
||||
})
|
||||
.refine(
|
||||
(val) => {
|
||||
if (val.maxItems !== undefined && val.minItems !== undefined) {
|
||||
return val.maxItems >= val.minItems;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxItems must be greater than or equal to minItems' }
|
||||
);
|
||||
|
||||
const zFloatFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zFloatCollectionFieldType,
|
||||
});
|
||||
export type FloatFieldCollectionInputInstance = z.infer<typeof zFloatFieldCollectionInputInstance>;
|
||||
export type FloatFieldCollectionInputTemplate = z.infer<typeof zFloatFieldCollectionInputTemplate>;
|
||||
export const isFloatFieldCollectionInputInstance = buildTypeGuard(zFloatFieldCollectionInputInstance);
|
||||
export const isFloatFieldCollectionInputTemplate = buildTypeGuard(zFloatFieldCollectionInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region StringField
|
||||
@@ -315,13 +415,55 @@ const zStringFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStringFieldType,
|
||||
});
|
||||
|
||||
// #region StringField Collection
|
||||
export const zStringFieldCollectionValue = z.array(zStringFieldValue).optional();
|
||||
const zStringFieldCollectionInputInstance = zFieldInputInstanceBase.extend({
|
||||
value: zStringFieldCollectionValue,
|
||||
});
|
||||
const zStringFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
.extend({
|
||||
type: zStringCollectionFieldType,
|
||||
originalType: zFieldType.optional(),
|
||||
default: zStringFieldCollectionValue,
|
||||
maxLength: z.number().int().gte(0).optional(),
|
||||
minLength: z.number().int().gte(0).optional(),
|
||||
maxItems: z.number().int().gte(0).optional(),
|
||||
minItems: z.number().int().gte(0).optional(),
|
||||
})
|
||||
.refine(
|
||||
(val) => {
|
||||
if (val.maxLength !== undefined && val.minLength !== undefined) {
|
||||
return val.maxLength >= val.minLength;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxLength must be greater than or equal to minLength' }
|
||||
)
|
||||
.refine(
|
||||
(val) => {
|
||||
if (val.maxItems !== undefined && val.minItems !== undefined) {
|
||||
return val.maxItems >= val.minItems;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxItems must be greater than or equal to minItems' }
|
||||
);
|
||||
|
||||
const zStringFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
type: zStringCollectionFieldType,
|
||||
});
|
||||
export type StringFieldCollectionValue = z.infer<typeof zStringFieldCollectionValue>;
|
||||
export type StringFieldCollectionInputInstance = z.infer<typeof zStringFieldCollectionInputInstance>;
|
||||
export type StringFieldCollectionInputTemplate = z.infer<typeof zStringFieldCollectionInputTemplate>;
|
||||
export const isStringFieldCollectionInputInstance = buildTypeGuard(zStringFieldCollectionInputInstance);
|
||||
export const isStringFieldCollectionInputTemplate = buildTypeGuard(zStringFieldCollectionInputTemplate);
|
||||
// #endregion
|
||||
|
||||
export type StringFieldValue = z.infer<typeof zStringFieldValue>;
|
||||
export type StringFieldInputInstance = z.infer<typeof zStringFieldInputInstance>;
|
||||
export type StringFieldInputTemplate = z.infer<typeof zStringFieldInputTemplate>;
|
||||
export const isStringFieldInputInstance = (val: unknown): val is StringFieldInputInstance =>
|
||||
zStringFieldInputInstance.safeParse(val).success;
|
||||
export const isStringFieldInputTemplate = (val: unknown): val is StringFieldInputTemplate =>
|
||||
zStringFieldInputTemplate.safeParse(val).success;
|
||||
export const isStringFieldInputInstance = buildTypeGuard(zStringFieldInputInstance);
|
||||
export const isStringFieldInputTemplate = buildTypeGuard(zStringFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region BooleanField
|
||||
@@ -341,10 +483,8 @@ const zBooleanFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type BooleanFieldValue = z.infer<typeof zBooleanFieldValue>;
|
||||
export type BooleanFieldInputInstance = z.infer<typeof zBooleanFieldInputInstance>;
|
||||
export type BooleanFieldInputTemplate = z.infer<typeof zBooleanFieldInputTemplate>;
|
||||
export const isBooleanFieldInputInstance = (val: unknown): val is BooleanFieldInputInstance =>
|
||||
zBooleanFieldInputInstance.safeParse(val).success;
|
||||
export const isBooleanFieldInputTemplate = (val: unknown): val is BooleanFieldInputTemplate =>
|
||||
zBooleanFieldInputTemplate.safeParse(val).success;
|
||||
export const isBooleanFieldInputInstance = buildTypeGuard(zBooleanFieldInputInstance);
|
||||
export const isBooleanFieldInputTemplate = buildTypeGuard(zBooleanFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region EnumField
|
||||
@@ -366,10 +506,8 @@ const zEnumFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type EnumFieldValue = z.infer<typeof zEnumFieldValue>;
|
||||
export type EnumFieldInputInstance = z.infer<typeof zEnumFieldInputInstance>;
|
||||
export type EnumFieldInputTemplate = z.infer<typeof zEnumFieldInputTemplate>;
|
||||
export const isEnumFieldInputInstance = (val: unknown): val is EnumFieldInputInstance =>
|
||||
zEnumFieldInputInstance.safeParse(val).success;
|
||||
export const isEnumFieldInputTemplate = (val: unknown): val is EnumFieldInputTemplate =>
|
||||
zEnumFieldInputTemplate.safeParse(val).success;
|
||||
export const isEnumFieldInputInstance = buildTypeGuard(zEnumFieldInputInstance);
|
||||
export const isEnumFieldInputTemplate = buildTypeGuard(zEnumFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region ImageField
|
||||
@@ -388,10 +526,8 @@ const zImageFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ImageFieldValue = z.infer<typeof zImageFieldValue>;
|
||||
export type ImageFieldInputInstance = z.infer<typeof zImageFieldInputInstance>;
|
||||
export type ImageFieldInputTemplate = z.infer<typeof zImageFieldInputTemplate>;
|
||||
export const isImageFieldInputInstance = (val: unknown): val is ImageFieldInputInstance =>
|
||||
zImageFieldInputInstance.safeParse(val).success;
|
||||
export const isImageFieldInputTemplate = (val: unknown): val is ImageFieldInputTemplate =>
|
||||
zImageFieldInputTemplate.safeParse(val).success;
|
||||
export const isImageFieldInputInstance = buildTypeGuard(zImageFieldInputInstance);
|
||||
export const isImageFieldInputTemplate = buildTypeGuard(zImageFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region ImageField Collection
|
||||
@@ -414,7 +550,7 @@ const zImageFieldCollectionInputTemplate = zFieldInputTemplateBase
|
||||
}
|
||||
return true;
|
||||
},
|
||||
{ message: 'maxLength must be greater than or equal to minLength' }
|
||||
{ message: 'maxItems must be greater than or equal to minItems' }
|
||||
);
|
||||
|
||||
const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
@@ -423,10 +559,8 @@ const zImageFieldCollectionOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ImageFieldCollectionValue = z.infer<typeof zImageFieldCollectionValue>;
|
||||
export type ImageFieldCollectionInputInstance = z.infer<typeof zImageFieldCollectionInputInstance>;
|
||||
export type ImageFieldCollectionInputTemplate = z.infer<typeof zImageFieldCollectionInputTemplate>;
|
||||
export const isImageFieldCollectionInputInstance = (val: unknown): val is ImageFieldCollectionInputInstance =>
|
||||
zImageFieldCollectionInputInstance.safeParse(val).success;
|
||||
export const isImageFieldCollectionInputTemplate = (val: unknown): val is ImageFieldCollectionInputTemplate =>
|
||||
zImageFieldCollectionInputTemplate.safeParse(val).success;
|
||||
export const isImageFieldCollectionInputInstance = buildTypeGuard(zImageFieldCollectionInputInstance);
|
||||
export const isImageFieldCollectionInputTemplate = buildTypeGuard(zImageFieldCollectionInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region BoardField
|
||||
@@ -446,10 +580,8 @@ const zBoardFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type BoardFieldValue = z.infer<typeof zBoardFieldValue>;
|
||||
export type BoardFieldInputInstance = z.infer<typeof zBoardFieldInputInstance>;
|
||||
export type BoardFieldInputTemplate = z.infer<typeof zBoardFieldInputTemplate>;
|
||||
export const isBoardFieldInputInstance = (val: unknown): val is BoardFieldInputInstance =>
|
||||
zBoardFieldInputInstance.safeParse(val).success;
|
||||
export const isBoardFieldInputTemplate = (val: unknown): val is BoardFieldInputTemplate =>
|
||||
zBoardFieldInputTemplate.safeParse(val).success;
|
||||
export const isBoardFieldInputInstance = buildTypeGuard(zBoardFieldInputInstance);
|
||||
export const isBoardFieldInputTemplate = buildTypeGuard(zBoardFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region ColorField
|
||||
@@ -469,10 +601,8 @@ const zColorFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ColorFieldValue = z.infer<typeof zColorFieldValue>;
|
||||
export type ColorFieldInputInstance = z.infer<typeof zColorFieldInputInstance>;
|
||||
export type ColorFieldInputTemplate = z.infer<typeof zColorFieldInputTemplate>;
|
||||
export const isColorFieldInputInstance = (val: unknown): val is ColorFieldInputInstance =>
|
||||
zColorFieldInputInstance.safeParse(val).success;
|
||||
export const isColorFieldInputTemplate = (val: unknown): val is ColorFieldInputTemplate =>
|
||||
zColorFieldInputTemplate.safeParse(val).success;
|
||||
export const isColorFieldInputInstance = buildTypeGuard(zColorFieldInputInstance);
|
||||
export const isColorFieldInputTemplate = buildTypeGuard(zColorFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region MainModelField
|
||||
@@ -492,10 +622,8 @@ const zMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type MainModelFieldValue = z.infer<typeof zMainModelFieldValue>;
|
||||
export type MainModelFieldInputInstance = z.infer<typeof zMainModelFieldInputInstance>;
|
||||
export type MainModelFieldInputTemplate = z.infer<typeof zMainModelFieldInputTemplate>;
|
||||
export const isMainModelFieldInputInstance = (val: unknown): val is MainModelFieldInputInstance =>
|
||||
zMainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isMainModelFieldInputTemplate = (val: unknown): val is MainModelFieldInputTemplate =>
|
||||
zMainModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isMainModelFieldInputInstance = buildTypeGuard(zMainModelFieldInputInstance);
|
||||
export const isMainModelFieldInputTemplate = buildTypeGuard(zMainModelFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region ModelIdentifierField
|
||||
@@ -514,10 +642,8 @@ const zModelIdentifierFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ModelIdentifierFieldValue = z.infer<typeof zModelIdentifierFieldValue>;
|
||||
export type ModelIdentifierFieldInputInstance = z.infer<typeof zModelIdentifierFieldInputInstance>;
|
||||
export type ModelIdentifierFieldInputTemplate = z.infer<typeof zModelIdentifierFieldInputTemplate>;
|
||||
export const isModelIdentifierFieldInputInstance = (val: unknown): val is ModelIdentifierFieldInputInstance =>
|
||||
zModelIdentifierFieldInputInstance.safeParse(val).success;
|
||||
export const isModelIdentifierFieldInputTemplate = (val: unknown): val is ModelIdentifierFieldInputTemplate =>
|
||||
zModelIdentifierFieldInputTemplate.safeParse(val).success;
|
||||
export const isModelIdentifierFieldInputInstance = buildTypeGuard(zModelIdentifierFieldInputInstance);
|
||||
export const isModelIdentifierFieldInputTemplate = buildTypeGuard(zModelIdentifierFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region SDXLMainModelField
|
||||
@@ -536,10 +662,8 @@ const zSDXLMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
});
|
||||
export type SDXLMainModelFieldInputInstance = z.infer<typeof zSDXLMainModelFieldInputInstance>;
|
||||
export type SDXLMainModelFieldInputTemplate = z.infer<typeof zSDXLMainModelFieldInputTemplate>;
|
||||
export const isSDXLMainModelFieldInputInstance = (val: unknown): val is SDXLMainModelFieldInputInstance =>
|
||||
zSDXLMainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSDXLMainModelFieldInputTemplate = (val: unknown): val is SDXLMainModelFieldInputTemplate =>
|
||||
zSDXLMainModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isSDXLMainModelFieldInputInstance = buildTypeGuard(zSDXLMainModelFieldInputInstance);
|
||||
export const isSDXLMainModelFieldInputTemplate = buildTypeGuard(zSDXLMainModelFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region SD3MainModelField
|
||||
@@ -558,10 +682,8 @@ const zSD3MainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
});
|
||||
export type SD3MainModelFieldInputInstance = z.infer<typeof zSD3MainModelFieldInputInstance>;
|
||||
export type SD3MainModelFieldInputTemplate = z.infer<typeof zSD3MainModelFieldInputTemplate>;
|
||||
export const isSD3MainModelFieldInputInstance = (val: unknown): val is SD3MainModelFieldInputInstance =>
|
||||
zSD3MainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSD3MainModelFieldInputTemplate = (val: unknown): val is SD3MainModelFieldInputTemplate =>
|
||||
zSD3MainModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isSD3MainModelFieldInputInstance = buildTypeGuard(zSD3MainModelFieldInputInstance);
|
||||
export const isSD3MainModelFieldInputTemplate = buildTypeGuard(zSD3MainModelFieldInputTemplate);
|
||||
|
||||
// #endregion
|
||||
|
||||
@@ -581,10 +703,8 @@ const zFluxMainModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
});
|
||||
export type FluxMainModelFieldInputInstance = z.infer<typeof zFluxMainModelFieldInputInstance>;
|
||||
export type FluxMainModelFieldInputTemplate = z.infer<typeof zFluxMainModelFieldInputTemplate>;
|
||||
export const isFluxMainModelFieldInputInstance = (val: unknown): val is FluxMainModelFieldInputInstance =>
|
||||
zFluxMainModelFieldInputInstance.safeParse(val).success;
|
||||
export const isFluxMainModelFieldInputTemplate = (val: unknown): val is FluxMainModelFieldInputTemplate =>
|
||||
zFluxMainModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isFluxMainModelFieldInputInstance = buildTypeGuard(zFluxMainModelFieldInputInstance);
|
||||
export const isFluxMainModelFieldInputTemplate = buildTypeGuard(zFluxMainModelFieldInputTemplate);
|
||||
|
||||
// #endregion
|
||||
|
||||
@@ -606,10 +726,8 @@ const zSDXLRefinerModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type SDXLRefinerModelFieldValue = z.infer<typeof zSDXLRefinerModelFieldValue>;
|
||||
export type SDXLRefinerModelFieldInputInstance = z.infer<typeof zSDXLRefinerModelFieldInputInstance>;
|
||||
export type SDXLRefinerModelFieldInputTemplate = z.infer<typeof zSDXLRefinerModelFieldInputTemplate>;
|
||||
export const isSDXLRefinerModelFieldInputInstance = (val: unknown): val is SDXLRefinerModelFieldInputInstance =>
|
||||
zSDXLRefinerModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSDXLRefinerModelFieldInputTemplate = (val: unknown): val is SDXLRefinerModelFieldInputTemplate =>
|
||||
zSDXLRefinerModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isSDXLRefinerModelFieldInputInstance = buildTypeGuard(zSDXLRefinerModelFieldInputInstance);
|
||||
export const isSDXLRefinerModelFieldInputTemplate = buildTypeGuard(zSDXLRefinerModelFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region VAEModelField
|
||||
@@ -629,10 +747,8 @@ const zVAEModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type VAEModelFieldValue = z.infer<typeof zVAEModelFieldValue>;
|
||||
export type VAEModelFieldInputInstance = z.infer<typeof zVAEModelFieldInputInstance>;
|
||||
export type VAEModelFieldInputTemplate = z.infer<typeof zVAEModelFieldInputTemplate>;
|
||||
export const isVAEModelFieldInputInstance = (val: unknown): val is VAEModelFieldInputInstance =>
|
||||
zVAEModelFieldInputInstance.safeParse(val).success;
|
||||
export const isVAEModelFieldInputTemplate = (val: unknown): val is VAEModelFieldInputTemplate =>
|
||||
zVAEModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isVAEModelFieldInputInstance = buildTypeGuard(zVAEModelFieldInputInstance);
|
||||
export const isVAEModelFieldInputTemplate = buildTypeGuard(zVAEModelFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region LoRAModelField
|
||||
@@ -652,10 +768,8 @@ const zLoRAModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type LoRAModelFieldValue = z.infer<typeof zLoRAModelFieldValue>;
|
||||
export type LoRAModelFieldInputInstance = z.infer<typeof zLoRAModelFieldInputInstance>;
|
||||
export type LoRAModelFieldInputTemplate = z.infer<typeof zLoRAModelFieldInputTemplate>;
|
||||
export const isLoRAModelFieldInputInstance = (val: unknown): val is LoRAModelFieldInputInstance =>
|
||||
zLoRAModelFieldInputInstance.safeParse(val).success;
|
||||
export const isLoRAModelFieldInputTemplate = (val: unknown): val is LoRAModelFieldInputTemplate =>
|
||||
zLoRAModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isLoRAModelFieldInputInstance = buildTypeGuard(zLoRAModelFieldInputInstance);
|
||||
export const isLoRAModelFieldInputTemplate = buildTypeGuard(zLoRAModelFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region ControlNetModelField
|
||||
@@ -675,10 +789,8 @@ const zControlNetModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type ControlNetModelFieldValue = z.infer<typeof zControlNetModelFieldValue>;
|
||||
export type ControlNetModelFieldInputInstance = z.infer<typeof zControlNetModelFieldInputInstance>;
|
||||
export type ControlNetModelFieldInputTemplate = z.infer<typeof zControlNetModelFieldInputTemplate>;
|
||||
export const isControlNetModelFieldInputInstance = (val: unknown): val is ControlNetModelFieldInputInstance =>
|
||||
zControlNetModelFieldInputInstance.safeParse(val).success;
|
||||
export const isControlNetModelFieldInputTemplate = (val: unknown): val is ControlNetModelFieldInputTemplate =>
|
||||
zControlNetModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isControlNetModelFieldInputInstance = buildTypeGuard(zControlNetModelFieldInputInstance);
|
||||
export const isControlNetModelFieldInputTemplate = buildTypeGuard(zControlNetModelFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region IPAdapterModelField
|
||||
@@ -698,10 +810,8 @@ const zIPAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type IPAdapterModelFieldValue = z.infer<typeof zIPAdapterModelFieldValue>;
|
||||
export type IPAdapterModelFieldInputInstance = z.infer<typeof zIPAdapterModelFieldInputInstance>;
|
||||
export type IPAdapterModelFieldInputTemplate = z.infer<typeof zIPAdapterModelFieldInputTemplate>;
|
||||
export const isIPAdapterModelFieldInputInstance = (val: unknown): val is IPAdapterModelFieldInputInstance =>
|
||||
zIPAdapterModelFieldInputInstance.safeParse(val).success;
|
||||
export const isIPAdapterModelFieldInputTemplate = (val: unknown): val is IPAdapterModelFieldInputTemplate =>
|
||||
zIPAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isIPAdapterModelFieldInputInstance = buildTypeGuard(zIPAdapterModelFieldInputInstance);
|
||||
export const isIPAdapterModelFieldInputTemplate = buildTypeGuard(zIPAdapterModelFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region T2IAdapterField
|
||||
@@ -721,10 +831,8 @@ const zT2IAdapterModelFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type T2IAdapterModelFieldValue = z.infer<typeof zT2IAdapterModelFieldValue>;
|
||||
export type T2IAdapterModelFieldInputInstance = z.infer<typeof zT2IAdapterModelFieldInputInstance>;
|
||||
export type T2IAdapterModelFieldInputTemplate = z.infer<typeof zT2IAdapterModelFieldInputTemplate>;
|
||||
export const isT2IAdapterModelFieldInputInstance = (val: unknown): val is T2IAdapterModelFieldInputInstance =>
|
||||
zT2IAdapterModelFieldInputInstance.safeParse(val).success;
|
||||
export const isT2IAdapterModelFieldInputTemplate = (val: unknown): val is T2IAdapterModelFieldInputTemplate =>
|
||||
zT2IAdapterModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isT2IAdapterModelFieldInputInstance = buildTypeGuard(zT2IAdapterModelFieldInputInstance);
|
||||
export const isT2IAdapterModelFieldInputTemplate = buildTypeGuard(zT2IAdapterModelFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region SpandrelModelToModelField
|
||||
@@ -744,14 +852,12 @@ const zSpandrelImageToImageModelFieldOutputTemplate = zFieldOutputTemplateBase.e
|
||||
export type SpandrelImageToImageModelFieldValue = z.infer<typeof zSpandrelImageToImageModelFieldValue>;
|
||||
export type SpandrelImageToImageModelFieldInputInstance = z.infer<typeof zSpandrelImageToImageModelFieldInputInstance>;
|
||||
export type SpandrelImageToImageModelFieldInputTemplate = z.infer<typeof zSpandrelImageToImageModelFieldInputTemplate>;
|
||||
export const isSpandrelImageToImageModelFieldInputInstance = (
|
||||
val: unknown
|
||||
): val is SpandrelImageToImageModelFieldInputInstance =>
|
||||
zSpandrelImageToImageModelFieldInputInstance.safeParse(val).success;
|
||||
export const isSpandrelImageToImageModelFieldInputTemplate = (
|
||||
val: unknown
|
||||
): val is SpandrelImageToImageModelFieldInputTemplate =>
|
||||
zSpandrelImageToImageModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isSpandrelImageToImageModelFieldInputInstance = buildTypeGuard(
|
||||
zSpandrelImageToImageModelFieldInputInstance
|
||||
);
|
||||
export const isSpandrelImageToImageModelFieldInputTemplate = buildTypeGuard(
|
||||
zSpandrelImageToImageModelFieldInputTemplate
|
||||
);
|
||||
// #endregion
|
||||
|
||||
// #region T5EncoderModelField
|
||||
@@ -770,10 +876,8 @@ export type T5EncoderModelFieldValue = z.infer<typeof zT5EncoderModelFieldValue>
|
||||
|
||||
export type T5EncoderModelFieldInputInstance = z.infer<typeof zT5EncoderModelFieldInputInstance>;
|
||||
export type T5EncoderModelFieldInputTemplate = z.infer<typeof zT5EncoderModelFieldInputTemplate>;
|
||||
export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5EncoderModelFieldInputInstance =>
|
||||
zT5EncoderModelFieldInputInstance.safeParse(val).success;
|
||||
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
|
||||
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isT5EncoderModelFieldInputInstance = buildTypeGuard(zT5EncoderModelFieldInputInstance);
|
||||
export const isT5EncoderModelFieldInputTemplate = buildTypeGuard(zT5EncoderModelFieldInputTemplate);
|
||||
|
||||
// #endregion
|
||||
|
||||
@@ -793,10 +897,8 @@ export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
|
||||
|
||||
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
|
||||
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
|
||||
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
|
||||
zFluxVAEModelFieldInputInstance.safeParse(val).success;
|
||||
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
|
||||
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isFluxVAEModelFieldInputInstance = buildTypeGuard(zFluxVAEModelFieldInputInstance);
|
||||
export const isFluxVAEModelFieldInputTemplate = buildTypeGuard(zFluxVAEModelFieldInputTemplate);
|
||||
|
||||
// #endregion
|
||||
|
||||
@@ -816,10 +918,8 @@ export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>
|
||||
|
||||
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
|
||||
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
|
||||
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
|
||||
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
|
||||
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isCLIPEmbedModelFieldInputInstance = buildTypeGuard(zCLIPEmbedModelFieldInputInstance);
|
||||
export const isCLIPEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPEmbedModelFieldInputTemplate);
|
||||
|
||||
// #endregion
|
||||
|
||||
@@ -839,10 +939,8 @@ export type CLIPLEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValu
|
||||
|
||||
export type CLIPLEmbedModelFieldInputInstance = z.infer<typeof zCLIPLEmbedModelFieldInputInstance>;
|
||||
export type CLIPLEmbedModelFieldInputTemplate = z.infer<typeof zCLIPLEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPLEmbedModelFieldInputInstance = (val: unknown): val is CLIPLEmbedModelFieldInputInstance =>
|
||||
zCLIPLEmbedModelFieldInputInstance.safeParse(val).success;
|
||||
export const isCLIPLEmbedModelFieldInputTemplate = (val: unknown): val is CLIPLEmbedModelFieldInputTemplate =>
|
||||
zCLIPLEmbedModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isCLIPLEmbedModelFieldInputInstance = buildTypeGuard(zCLIPLEmbedModelFieldInputInstance);
|
||||
export const isCLIPLEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPLEmbedModelFieldInputTemplate);
|
||||
|
||||
// #endregion
|
||||
|
||||
@@ -862,10 +960,8 @@ export type CLIPGEmbedModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldValu
|
||||
|
||||
export type CLIPGEmbedModelFieldInputInstance = z.infer<typeof zCLIPGEmbedModelFieldInputInstance>;
|
||||
export type CLIPGEmbedModelFieldInputTemplate = z.infer<typeof zCLIPGEmbedModelFieldInputTemplate>;
|
||||
export const isCLIPGEmbedModelFieldInputInstance = (val: unknown): val is CLIPGEmbedModelFieldInputInstance =>
|
||||
zCLIPGEmbedModelFieldInputInstance.safeParse(val).success;
|
||||
export const isCLIPGEmbedModelFieldInputTemplate = (val: unknown): val is CLIPGEmbedModelFieldInputTemplate =>
|
||||
zCLIPGEmbedModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isCLIPGEmbedModelFieldInputInstance = buildTypeGuard(zCLIPGEmbedModelFieldInputInstance);
|
||||
export const isCLIPGEmbedModelFieldInputTemplate = buildTypeGuard(zCLIPGEmbedModelFieldInputTemplate);
|
||||
|
||||
// #endregion
|
||||
|
||||
@@ -885,10 +981,8 @@ export type ControlLoRAModelFieldValue = z.infer<typeof zCLIPLEmbedModelFieldVal
|
||||
|
||||
export type ControlLoRAModelFieldInputInstance = z.infer<typeof zControlLoRAModelFieldInputInstance>;
|
||||
export type ControlLoRAModelFieldInputTemplate = z.infer<typeof zControlLoRAModelFieldInputTemplate>;
|
||||
export const isControlLoRAModelFieldInputInstance = (val: unknown): val is ControlLoRAModelFieldInputInstance =>
|
||||
zControlLoRAModelFieldInputInstance.safeParse(val).success;
|
||||
export const isControlLoRAModelFieldInputTemplate = (val: unknown): val is ControlLoRAModelFieldInputTemplate =>
|
||||
zControlLoRAModelFieldInputTemplate.safeParse(val).success;
|
||||
export const isControlLoRAModelFieldInputInstance = buildTypeGuard(zControlLoRAModelFieldInputInstance);
|
||||
export const isControlLoRAModelFieldInputTemplate = buildTypeGuard(zControlLoRAModelFieldInputTemplate);
|
||||
|
||||
// #endregion
|
||||
|
||||
@@ -909,10 +1003,8 @@ const zSchedulerFieldOutputTemplate = zFieldOutputTemplateBase.extend({
|
||||
export type SchedulerFieldValue = z.infer<typeof zSchedulerFieldValue>;
|
||||
export type SchedulerFieldInputInstance = z.infer<typeof zSchedulerFieldInputInstance>;
|
||||
export type SchedulerFieldInputTemplate = z.infer<typeof zSchedulerFieldInputTemplate>;
|
||||
export const isSchedulerFieldInputInstance = (val: unknown): val is SchedulerFieldInputInstance =>
|
||||
zSchedulerFieldInputInstance.safeParse(val).success;
|
||||
export const isSchedulerFieldInputTemplate = (val: unknown): val is SchedulerFieldInputTemplate =>
|
||||
zSchedulerFieldInputTemplate.safeParse(val).success;
|
||||
export const isSchedulerFieldInputInstance = buildTypeGuard(zSchedulerFieldInputInstance);
|
||||
export const isSchedulerFieldInputTemplate = buildTypeGuard(zSchedulerFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region StatelessField
|
||||
@@ -963,8 +1055,11 @@ export type StatelessFieldInputTemplate = z.infer<typeof zStatelessFieldInputTem
|
||||
// #region StatefulFieldValue & FieldValue
|
||||
export const zStatefulFieldValue = z.union([
|
||||
zIntegerFieldValue,
|
||||
zIntegerFieldCollectionValue,
|
||||
zFloatFieldValue,
|
||||
zFloatFieldCollectionValue,
|
||||
zStringFieldValue,
|
||||
zStringFieldCollectionValue,
|
||||
zBooleanFieldValue,
|
||||
zEnumFieldValue,
|
||||
zImageFieldValue,
|
||||
@@ -1000,8 +1095,11 @@ export type FieldValue = z.infer<typeof zFieldValue>;
|
||||
// #region StatefulFieldInputInstance & FieldInputInstance
|
||||
const zStatefulFieldInputInstance = z.union([
|
||||
zIntegerFieldInputInstance,
|
||||
zIntegerFieldCollectionInputInstance,
|
||||
zFloatFieldInputInstance,
|
||||
zFloatFieldCollectionInputInstance,
|
||||
zStringFieldInputInstance,
|
||||
zStringFieldCollectionInputInstance,
|
||||
zBooleanFieldInputInstance,
|
||||
zEnumFieldInputInstance,
|
||||
zImageFieldInputInstance,
|
||||
@@ -1028,15 +1126,17 @@ const zStatefulFieldInputInstance = z.union([
|
||||
|
||||
export const zFieldInputInstance = z.union([zStatefulFieldInputInstance, zStatelessFieldInputInstance]);
|
||||
export type FieldInputInstance = z.infer<typeof zFieldInputInstance>;
|
||||
export const isFieldInputInstance = (val: unknown): val is FieldInputInstance =>
|
||||
zFieldInputInstance.safeParse(val).success;
|
||||
export const isFieldInputInstance = buildTypeGuard(zFieldInputInstance);
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldInputTemplate & FieldInputTemplate
|
||||
const zStatefulFieldInputTemplate = z.union([
|
||||
zIntegerFieldInputTemplate,
|
||||
zIntegerFieldCollectionInputTemplate,
|
||||
zFloatFieldInputTemplate,
|
||||
zFloatFieldCollectionInputTemplate,
|
||||
zStringFieldInputTemplate,
|
||||
zStringFieldCollectionInputTemplate,
|
||||
zBooleanFieldInputTemplate,
|
||||
zEnumFieldInputTemplate,
|
||||
zImageFieldInputTemplate,
|
||||
@@ -1067,15 +1167,17 @@ const zStatefulFieldInputTemplate = z.union([
|
||||
|
||||
export const zFieldInputTemplate = z.union([zStatefulFieldInputTemplate, zStatelessFieldInputTemplate]);
|
||||
export type FieldInputTemplate = z.infer<typeof zFieldInputTemplate>;
|
||||
export const isFieldInputTemplate = (val: unknown): val is FieldInputTemplate =>
|
||||
zFieldInputTemplate.safeParse(val).success;
|
||||
export const isFieldInputTemplate = buildTypeGuard(zFieldInputTemplate);
|
||||
// #endregion
|
||||
|
||||
// #region StatefulFieldOutputTemplate & FieldOutputTemplate
|
||||
const zStatefulFieldOutputTemplate = z.union([
|
||||
zIntegerFieldOutputTemplate,
|
||||
zIntegerFieldCollectionOutputTemplate,
|
||||
zFloatFieldOutputTemplate,
|
||||
zFloatFieldCollectionOutputTemplate,
|
||||
zStringFieldOutputTemplate,
|
||||
zStringFieldCollectionOutputTemplate,
|
||||
zBooleanFieldOutputTemplate,
|
||||
zEnumFieldOutputTemplate,
|
||||
zImageFieldOutputTemplate,
|
||||
|
||||
@@ -0,0 +1,133 @@
|
||||
import type {
|
||||
FloatFieldCollectionInputInstance,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
ImageFieldCollectionInputTemplate,
|
||||
ImageFieldCollectionValue,
|
||||
IntegerFieldCollectionInputInstance,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
StringFieldCollectionInputTemplate,
|
||||
StringFieldCollectionValue,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
floatRangeStartStepCountGenerator,
|
||||
integerRangeStartStepCountGenerator,
|
||||
} from 'features/nodes/types/generators';
|
||||
import { t } from 'i18next';
|
||||
|
||||
export const validateImageFieldCollectionValue = (
|
||||
value: NonNullable<ImageFieldCollectionValue>,
|
||||
template: ImageFieldCollectionInputTemplate
|
||||
): string[] => {
|
||||
const reasons: string[] = [];
|
||||
const { minItems, maxItems } = template;
|
||||
const count = value.length;
|
||||
|
||||
// Image collections may have min or max items to validate
|
||||
if (minItems !== undefined && minItems > 0 && count === 0) {
|
||||
reasons.push(t('parameters.invoke.collectionEmpty'));
|
||||
}
|
||||
|
||||
if (minItems !== undefined && count < minItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
|
||||
}
|
||||
|
||||
if (maxItems !== undefined && count > maxItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
|
||||
}
|
||||
|
||||
return reasons;
|
||||
};
|
||||
|
||||
export const validateStringFieldCollectionValue = (
|
||||
value: NonNullable<StringFieldCollectionValue>,
|
||||
template: StringFieldCollectionInputTemplate
|
||||
): string[] => {
|
||||
const reasons: string[] = [];
|
||||
const { minItems, maxItems, minLength, maxLength } = template;
|
||||
const count = value.length;
|
||||
|
||||
// Image collections may have min or max items to validate
|
||||
if (minItems !== undefined && minItems > 0 && count === 0) {
|
||||
reasons.push(t('parameters.invoke.collectionEmpty'));
|
||||
}
|
||||
|
||||
if (minItems !== undefined && count < minItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
|
||||
}
|
||||
|
||||
if (maxItems !== undefined && count > maxItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
|
||||
}
|
||||
|
||||
for (const str of value) {
|
||||
if (maxLength !== undefined && str.length > maxLength) {
|
||||
reasons.push(t('parameters.invoke.collectionStringTooLong', { value, maxLength }));
|
||||
}
|
||||
if (minLength !== undefined && str.length < minLength) {
|
||||
reasons.push(t('parameters.invoke.collectionStringTooShort', { value, minLength }));
|
||||
}
|
||||
}
|
||||
|
||||
return reasons;
|
||||
};
|
||||
|
||||
export const resolveNumberFieldCollectionValue = (
|
||||
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance
|
||||
): number[] | undefined => {
|
||||
if (field.generator?.type === 'float-range-generator-start-step-count') {
|
||||
return floatRangeStartStepCountGenerator(field.generator);
|
||||
} else if (field.generator?.type === 'integer-range-generator-start-step-count') {
|
||||
return integerRangeStartStepCountGenerator(field.generator);
|
||||
} else {
|
||||
return field.value;
|
||||
}
|
||||
};
|
||||
|
||||
export const validateNumberFieldCollectionValue = (
|
||||
field: IntegerFieldCollectionInputInstance | FloatFieldCollectionInputInstance,
|
||||
template: IntegerFieldCollectionInputTemplate | FloatFieldCollectionInputTemplate
|
||||
): string[] => {
|
||||
const reasons: string[] = [];
|
||||
const { minItems, maxItems, minimum, maximum, exclusiveMinimum, exclusiveMaximum, multipleOf } = template;
|
||||
const value = resolveNumberFieldCollectionValue(field);
|
||||
|
||||
if (value === undefined) {
|
||||
reasons.push(t('parameters.invoke.collectionEmpty'));
|
||||
return reasons;
|
||||
}
|
||||
|
||||
const count = value.length;
|
||||
|
||||
// Image collections may have min or max items to validate
|
||||
if (minItems !== undefined && minItems > 0 && count === 0) {
|
||||
reasons.push(t('parameters.invoke.collectionEmpty'));
|
||||
}
|
||||
|
||||
if (minItems !== undefined && count < minItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooFewItems', { count, minItems }));
|
||||
}
|
||||
|
||||
if (maxItems !== undefined && count > maxItems) {
|
||||
reasons.push(t('parameters.invoke.collectionTooManyItems', { count, maxItems }));
|
||||
}
|
||||
|
||||
for (const num of value) {
|
||||
if (maximum !== undefined && num > maximum) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberGTMax', { value, maximum }));
|
||||
}
|
||||
if (minimum !== undefined && num < minimum) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberLTMin', { value, minimum }));
|
||||
}
|
||||
if (exclusiveMaximum !== undefined && num >= exclusiveMaximum) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberGTExclusiveMax', { value, exclusiveMaximum }));
|
||||
}
|
||||
if (exclusiveMinimum !== undefined && num <= exclusiveMinimum) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberLTExclusiveMin', { value, exclusiveMinimum }));
|
||||
}
|
||||
if (multipleOf !== undefined && num % multipleOf !== 0) {
|
||||
reasons.push(t('parameters.invoke.collectionNumberNotMultipleOf', { value, multipleOf }));
|
||||
}
|
||||
}
|
||||
|
||||
return reasons;
|
||||
};
|
||||
29
invokeai/frontend/web/src/features/nodes/types/generators.ts
Normal file
29
invokeai/frontend/web/src/features/nodes/types/generators.ts
Normal file
@@ -0,0 +1,29 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
export const zFloatRangeStartStepCountGenerator = z.object({
|
||||
type: z.literal('float-range-generator-start-step-count').default('float-range-generator-start-step-count'),
|
||||
start: z.number().default(0),
|
||||
step: z.number().default(1),
|
||||
count: z.number().int().default(10),
|
||||
});
|
||||
export type FloatRangeStartStepCountGenerator = z.infer<typeof zFloatRangeStartStepCountGenerator>;
|
||||
export const floatRangeStartStepCountGenerator = (generator: FloatRangeStartStepCountGenerator): number[] => {
|
||||
const { start, step, count } = generator;
|
||||
return Array.from({ length: count }, (_, i) => start + i * step);
|
||||
};
|
||||
export const getDefaultFloatRangeStartStepCountGenerator = (): FloatRangeStartStepCountGenerator =>
|
||||
zFloatRangeStartStepCountGenerator.parse({});
|
||||
|
||||
export const zIntegerRangeStartStepCountGenerator = z.object({
|
||||
type: z.literal('integer-range-generator-start-step-count').default('integer-range-generator-start-step-count'),
|
||||
start: z.number().int().default(0),
|
||||
step: z.number().int().default(1),
|
||||
count: z.number().int().default(10),
|
||||
});
|
||||
export type IntegerRangeStartStepCountGenerator = z.infer<typeof zIntegerRangeStartStepCountGenerator>;
|
||||
export const integerRangeStartStepCountGenerator = (generator: IntegerRangeStartStepCountGenerator): number[] => {
|
||||
const { start, step, count } = generator;
|
||||
return Array.from({ length: count }, (_, i) => start + i * step);
|
||||
};
|
||||
export const getDefaultIntegerRangeStartStepCountGenerator = (): IntegerRangeStartStepCountGenerator =>
|
||||
zIntegerRangeStartStepCountGenerator.parse({});
|
||||
@@ -91,3 +91,15 @@ const zInvocationNodeEdgeExtra = z.object({
|
||||
type InvocationNodeEdgeExtra = z.infer<typeof zInvocationNodeEdgeExtra>;
|
||||
export type InvocationNodeEdge = Edge<InvocationNodeEdgeExtra>;
|
||||
// #endregion
|
||||
|
||||
export const isBatchNode = (node: InvocationNode) => {
|
||||
switch (node.data.type) {
|
||||
case 'image_batch':
|
||||
case 'string_batch':
|
||||
case 'integer_batch':
|
||||
case 'float_batch':
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { NodesState } from 'features/nodes/store/types';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { omit, reduce } from 'lodash-es';
|
||||
import { isFloatFieldCollectionInputInstance, isIntegerFieldCollectionInputInstance } from 'features/nodes/types/field';
|
||||
import { resolveNumberFieldCollectionValue } from 'features/nodes/types/fieldValidators';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { negate, omit, reduce } from 'lodash-es';
|
||||
import type { AnyInvocation, Graph } from 'services/api/types';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
|
||||
@@ -14,7 +16,7 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
const { nodes, edges } = nodesState;
|
||||
|
||||
// Exclude all batch nodes - we will handle these in the batch setup in a diff function
|
||||
const filteredNodes = nodes.filter(isInvocationNode).filter((node) => node.data.type !== 'image_batch');
|
||||
const filteredNodes = nodes.filter(isInvocationNode).filter(negate(isBatchNode));
|
||||
|
||||
// Reduce the node editor nodes into invocation graph nodes
|
||||
const parsedNodes = filteredNodes.reduce<NonNullable<Graph['nodes']>>((nodesAccumulator, node) => {
|
||||
@@ -25,7 +27,11 @@ export const buildNodesGraph = (nodesState: NodesState): Graph => {
|
||||
const transformedInputs = reduce(
|
||||
inputs,
|
||||
(inputsAccumulator, input, name) => {
|
||||
inputsAccumulator[name] = input.value;
|
||||
if (isFloatFieldCollectionInputInstance(input) || isIntegerFieldCollectionInputInstance(input)) {
|
||||
inputsAccumulator[name] = resolveNumberFieldCollectionValue(input);
|
||||
} else {
|
||||
inputsAccumulator[name] = input.value;
|
||||
}
|
||||
|
||||
return inputsAccumulator;
|
||||
},
|
||||
|
||||
@@ -11,11 +11,13 @@ import type {
|
||||
EnumFieldInputTemplate,
|
||||
FieldInputTemplate,
|
||||
FieldType,
|
||||
FloatFieldCollectionInputTemplate,
|
||||
FloatFieldInputTemplate,
|
||||
FluxMainModelFieldInputTemplate,
|
||||
FluxVAEModelFieldInputTemplate,
|
||||
ImageFieldCollectionInputTemplate,
|
||||
ImageFieldInputTemplate,
|
||||
IntegerFieldCollectionInputTemplate,
|
||||
IntegerFieldInputTemplate,
|
||||
IPAdapterModelFieldInputTemplate,
|
||||
LoRAModelFieldInputTemplate,
|
||||
@@ -28,12 +30,19 @@ import type {
|
||||
SpandrelImageToImageModelFieldInputTemplate,
|
||||
StatefulFieldType,
|
||||
StatelessFieldInputTemplate,
|
||||
StringFieldCollectionInputTemplate,
|
||||
StringFieldInputTemplate,
|
||||
T2IAdapterModelFieldInputTemplate,
|
||||
T5EncoderModelFieldInputTemplate,
|
||||
VAEModelFieldInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import { isImageCollectionFieldType, isStatefulFieldType } from 'features/nodes/types/field';
|
||||
import {
|
||||
isFloatCollectionFieldType,
|
||||
isImageCollectionFieldType,
|
||||
isIntegerCollectionFieldType,
|
||||
isStatefulFieldType,
|
||||
isStringCollectionFieldType,
|
||||
} from 'features/nodes/types/field';
|
||||
import type { InvocationFieldSchema } from 'features/nodes/types/openapi';
|
||||
import { isSchemaObject } from 'features/nodes/types/openapi';
|
||||
import { t } from 'i18next';
|
||||
@@ -77,6 +86,48 @@ const buildIntegerFieldInputTemplate: FieldInputTemplateBuilder<IntegerFieldInpu
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildIntegerFieldCollectionInputTemplate: FieldInputTemplateBuilder<IntegerFieldCollectionInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: IntegerFieldCollectionInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
|
||||
};
|
||||
|
||||
if (schemaObject.minItems !== undefined) {
|
||||
template.minItems = schemaObject.minItems;
|
||||
}
|
||||
|
||||
if (schemaObject.maxItems !== undefined) {
|
||||
template.maxItems = schemaObject.maxItems;
|
||||
}
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -111,6 +162,48 @@ const buildFloatFieldInputTemplate: FieldInputTemplateBuilder<FloatFieldInputTem
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildFloatFieldCollectionInputTemplate: FieldInputTemplateBuilder<FloatFieldCollectionInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: FloatFieldCollectionInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
|
||||
};
|
||||
|
||||
if (schemaObject.minItems !== undefined) {
|
||||
template.minItems = schemaObject.minItems;
|
||||
}
|
||||
|
||||
if (schemaObject.maxItems !== undefined) {
|
||||
template.maxItems = schemaObject.maxItems;
|
||||
}
|
||||
|
||||
if (schemaObject.multipleOf !== undefined) {
|
||||
template.multipleOf = schemaObject.multipleOf;
|
||||
}
|
||||
|
||||
if (schemaObject.maximum !== undefined) {
|
||||
template.maximum = schemaObject.maximum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMaximum !== undefined && isNumber(schemaObject.exclusiveMaximum)) {
|
||||
template.exclusiveMaximum = schemaObject.exclusiveMaximum;
|
||||
}
|
||||
|
||||
if (schemaObject.minimum !== undefined) {
|
||||
template.minimum = schemaObject.minimum;
|
||||
}
|
||||
|
||||
if (schemaObject.exclusiveMinimum !== undefined && isNumber(schemaObject.exclusiveMinimum)) {
|
||||
template.exclusiveMinimum = schemaObject.exclusiveMinimum;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -133,6 +226,36 @@ const buildStringFieldInputTemplate: FieldInputTemplateBuilder<StringFieldInputT
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildStringFieldCollectionInputTemplate: FieldInputTemplateBuilder<StringFieldCollectionInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
fieldType,
|
||||
}) => {
|
||||
const template: StringFieldCollectionInputTemplate = {
|
||||
...baseField,
|
||||
type: fieldType,
|
||||
default: schemaObject.default ?? (schemaObject.orig_required ? [] : undefined),
|
||||
};
|
||||
|
||||
if (schemaObject.minLength !== undefined) {
|
||||
template.minLength = schemaObject.minLength;
|
||||
}
|
||||
|
||||
if (schemaObject.maxLength !== undefined) {
|
||||
template.maxLength = schemaObject.maxLength;
|
||||
}
|
||||
|
||||
if (schemaObject.minItems !== undefined) {
|
||||
template.minItems = schemaObject.minItems;
|
||||
}
|
||||
|
||||
if (schemaObject.maxItems !== undefined) {
|
||||
template.maxItems = schemaObject.maxItems;
|
||||
}
|
||||
|
||||
return template;
|
||||
};
|
||||
|
||||
const buildBooleanFieldInputTemplate: FieldInputTemplateBuilder<BooleanFieldInputTemplate> = ({
|
||||
schemaObject,
|
||||
baseField,
|
||||
@@ -569,12 +692,29 @@ export const buildFieldInputTemplate = (
|
||||
|
||||
if (isStatefulFieldType(fieldType)) {
|
||||
if (isImageCollectionFieldType(fieldType)) {
|
||||
fieldType;
|
||||
return buildImageFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else if (isStringCollectionFieldType(fieldType)) {
|
||||
return buildStringFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else if (isIntegerCollectionFieldType(fieldType)) {
|
||||
return buildIntegerFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else if (isFloatCollectionFieldType(fieldType)) {
|
||||
return buildFloatFieldCollectionInputTemplate({
|
||||
schemaObject: fieldSchema,
|
||||
baseField,
|
||||
fieldType,
|
||||
});
|
||||
} else {
|
||||
const builder = TEMPLATE_BUILDER_MAP[fieldType.name];
|
||||
const template = builder({
|
||||
|
||||
@@ -20,7 +20,7 @@ import { z } from 'zod';
|
||||
* @param schema The zod schema to create a type guard from.
|
||||
* @returns A type guard function for the schema.
|
||||
*/
|
||||
const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
|
||||
export const buildTypeGuard = <T extends z.ZodTypeAny>(schema: T) => {
|
||||
return (val: unknown): val is z.infer<T> => schema.safeParse(val).success;
|
||||
};
|
||||
|
||||
|
||||
@@ -206,8 +206,16 @@ const QueueCountPredictionWorkflowsTab = memo(() => {
|
||||
const iterationsCount = useAppSelector(selectIterations);
|
||||
|
||||
const text = useMemo(() => {
|
||||
const generationCount = Math.min(batchSize * iterationsCount, 10000);
|
||||
const iterations = t('queue.iterations', { count: iterationsCount });
|
||||
if (batchSize === 'NO_BATCHES') {
|
||||
const generationCount = Math.min(10000, iterationsCount);
|
||||
const generations = t('queue.generations', { count: generationCount });
|
||||
return `${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
|
||||
}
|
||||
if (batchSize === 'INVALID') {
|
||||
return t('parameters.invoke.invalidBatchConfiguration');
|
||||
}
|
||||
const generationCount = Math.min(batchSize * iterationsCount, 10000);
|
||||
const generations = t('queue.generations', { count: generationCount });
|
||||
return `${batchSize} ${t('queue.batchSize')} \u00d7 ${iterationsCount} ${iterations} -> ${generationCount} ${generations}`.toLowerCase();
|
||||
}, [batchSize, iterationsCount, t]);
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createSelector } from '@reduxjs/toolkit';
|
||||
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
|
||||
import type { AppConfig } from 'app/types/invokeai';
|
||||
import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
|
||||
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
|
||||
@@ -18,14 +19,31 @@ import { selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { NodesState, Templates } from 'features/nodes/store/types';
|
||||
import type { WorkflowSettingsState } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { selectWorkflowSettingsSlice } from 'features/nodes/store/workflowSettingsSlice';
|
||||
import { isImageFieldCollectionInputInstance, isImageFieldCollectionInputTemplate } from 'features/nodes/types/field';
|
||||
import { isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import {
|
||||
isFloatFieldCollectionInputInstance,
|
||||
isFloatFieldCollectionInputTemplate,
|
||||
isImageFieldCollectionInputInstance,
|
||||
isImageFieldCollectionInputTemplate,
|
||||
isIntegerFieldCollectionInputInstance,
|
||||
isIntegerFieldCollectionInputTemplate,
|
||||
isStringFieldCollectionInputInstance,
|
||||
isStringFieldCollectionInputTemplate,
|
||||
} from 'features/nodes/types/field';
|
||||
import {
|
||||
resolveNumberFieldCollectionValue,
|
||||
validateImageFieldCollectionValue,
|
||||
validateNumberFieldCollectionValue,
|
||||
validateStringFieldCollectionValue,
|
||||
} from 'features/nodes/types/fieldValidators';
|
||||
import type { InvocationNode } from 'features/nodes/types/invocation';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import type { UpscaleState } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectUpscaleSlice } from 'features/parameters/store/upscaleSlice';
|
||||
import { selectConfigSlice } from 'features/system/store/configSlice';
|
||||
import i18n from 'i18next';
|
||||
import { forEach, upperFirst } from 'lodash-es';
|
||||
import { forEach, groupBy, negate, upperFirst } from 'lodash-es';
|
||||
import { getConnectedEdges } from 'reactflow';
|
||||
import { assert } from 'tsafe';
|
||||
|
||||
/**
|
||||
* This file contains selectors and utilities for determining the app is ready to enqueue generations. The handling
|
||||
@@ -61,11 +79,56 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
}
|
||||
|
||||
if (workflowSettings.shouldValidateGraph) {
|
||||
if (!nodes.nodes.length) {
|
||||
const invocationNodes = nodes.nodes.filter(isInvocationNode);
|
||||
const batchNodes = invocationNodes.filter(isBatchNode);
|
||||
const nonBatchNodes = invocationNodes.filter(negate(isBatchNode));
|
||||
|
||||
if (!nonBatchNodes.length) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.noNodesInGraph') });
|
||||
}
|
||||
|
||||
nodes.nodes.forEach((node) => {
|
||||
for (const node of batchNodes) {
|
||||
if (nodes.edges.find((e) => e.source === node.id) === undefined) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.batchNodeNotConnected', { label: node.data.label }) });
|
||||
}
|
||||
}
|
||||
|
||||
if (batchNodes.length > 1) {
|
||||
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
|
||||
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
|
||||
if (batchGroupId === 'None') {
|
||||
// Ungrouped batch nodes may have differing collection sizes
|
||||
continue;
|
||||
}
|
||||
|
||||
// But grouped batch nodes must have the same collection size
|
||||
const collectionSizes: number[] = [];
|
||||
|
||||
for (const node of batchNodes) {
|
||||
if (node.data.type === 'image_batch') {
|
||||
assert(isImageFieldCollectionInputInstance(node.data.inputs.images));
|
||||
collectionSizes.push(node.data.inputs.images.value?.length ?? 0);
|
||||
} else if (node.data.type === 'string_batch') {
|
||||
assert(isStringFieldCollectionInputInstance(node.data.inputs.strings));
|
||||
collectionSizes.push(node.data.inputs.strings.value?.length ?? 0);
|
||||
} else if (node.data.type === 'float_batch') {
|
||||
assert(isFloatFieldCollectionInputInstance(node.data.inputs.floats));
|
||||
collectionSizes.push(node.data.inputs.floats.value?.length ?? 0);
|
||||
} else if (node.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(node.data.inputs.integers));
|
||||
collectionSizes.push(node.data.inputs.integers.value?.length ?? 0);
|
||||
}
|
||||
}
|
||||
|
||||
if (collectionSizes.some((count) => count !== collectionSizes[0])) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.batchNodeCollectionSizeMismatch', { batchGroupId }),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nonBatchNodes.forEach((node) => {
|
||||
if (!isInvocationNode(node)) {
|
||||
return;
|
||||
}
|
||||
@@ -91,45 +154,38 @@ const getReasonsWhyCannotEnqueueWorkflowsTab = (arg: {
|
||||
return;
|
||||
}
|
||||
|
||||
const baseTKeyOptions = {
|
||||
nodeLabel: node.data.label || nodeTemplate.title,
|
||||
fieldLabel: field.label || fieldTemplate.title,
|
||||
};
|
||||
const prefix = `${node.data.label || nodeTemplate.title} -> ${field.label || fieldTemplate.title}`;
|
||||
|
||||
if (fieldTemplate.required && field.value === undefined && !hasConnection) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.missingInputForField', baseTKeyOptions) });
|
||||
return;
|
||||
reasons.push({ prefix, content: i18n.t('parameters.invoke.missingInputForField') });
|
||||
} else if (
|
||||
field.value &&
|
||||
isImageFieldCollectionInputInstance(field) &&
|
||||
isImageFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
// Image collections may have min or max items to validate
|
||||
// TODO(psyche): generalize this to other collection types
|
||||
if (fieldTemplate.minItems !== undefined && fieldTemplate.minItems > 0 && field.value.length === 0) {
|
||||
reasons.push({ content: i18n.t('parameters.invoke.collectionEmpty', baseTKeyOptions) });
|
||||
return;
|
||||
}
|
||||
if (fieldTemplate.minItems !== undefined && field.value.length < fieldTemplate.minItems) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.collectionTooFewItems', {
|
||||
...baseTKeyOptions,
|
||||
size: field.value.length,
|
||||
minItems: fieldTemplate.minItems,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (fieldTemplate.maxItems !== undefined && field.value.length > fieldTemplate.maxItems) {
|
||||
reasons.push({
|
||||
content: i18n.t('parameters.invoke.collectionTooManyItems', {
|
||||
...baseTKeyOptions,
|
||||
size: field.value.length,
|
||||
maxItems: fieldTemplate.maxItems,
|
||||
}),
|
||||
});
|
||||
return;
|
||||
}
|
||||
const errors = validateImageFieldCollectionValue(field.value, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
} else if (
|
||||
field.value &&
|
||||
isStringFieldCollectionInputInstance(field) &&
|
||||
isStringFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateStringFieldCollectionValue(field.value, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
} else if (
|
||||
field.value &&
|
||||
isIntegerFieldCollectionInputInstance(field) &&
|
||||
isIntegerFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
} else if (
|
||||
field.value &&
|
||||
isFloatFieldCollectionInputInstance(field) &&
|
||||
isFloatFieldCollectionInputTemplate(fieldTemplate)
|
||||
) {
|
||||
const errors = validateNumberFieldCollectionValue(field, fieldTemplate);
|
||||
reasons.push(...errors.map((error) => ({ prefix, content: error })));
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -491,17 +547,97 @@ export const selectPromptsCount = createSelector(
|
||||
(params, dynamicPrompts) => (getShouldProcessPrompt(params.positivePrompt) ? dynamicPrompts.prompts.length : 1)
|
||||
);
|
||||
|
||||
export const selectWorkflowsBatchSize = createSelector(selectNodesSlice, ({ nodes }) =>
|
||||
// The batch size is the product of all batch nodes' collection sizes
|
||||
nodes.filter(isInvocationNode).reduce((batchSize, node) => {
|
||||
if (!isImageFieldCollectionInputInstance(node.data.inputs.images)) {
|
||||
return batchSize;
|
||||
}
|
||||
// If the batch size is not set, default to 1
|
||||
batchSize = batchSize || 1;
|
||||
// Multiply the batch size by the number of images in the batch
|
||||
batchSize = batchSize * (node.data.inputs.images.value?.length ?? 0);
|
||||
const getBatchCollectionSize = (batchNode: InvocationNode) => {
|
||||
if (batchNode.data.type === 'image_batch') {
|
||||
assert(isImageFieldCollectionInputInstance(batchNode.data.inputs.images));
|
||||
return batchNode.data.inputs.images.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'string_batch') {
|
||||
assert(isStringFieldCollectionInputInstance(batchNode.data.inputs.strings));
|
||||
return batchNode.data.inputs.strings.value?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'float_batch') {
|
||||
assert(isFloatFieldCollectionInputInstance(batchNode.data.inputs.floats));
|
||||
return resolveNumberFieldCollectionValue(batchNode.data.inputs.floats)?.length ?? 0;
|
||||
} else if (batchNode.data.type === 'integer_batch') {
|
||||
assert(isIntegerFieldCollectionInputInstance(batchNode.data.inputs.integers));
|
||||
return resolveNumberFieldCollectionValue(batchNode.data.inputs.integers)?.length ?? 0;
|
||||
}
|
||||
return 0;
|
||||
};
|
||||
|
||||
return batchSize;
|
||||
}, 0)
|
||||
const buildSelectGroupBatchSizes = (batchGroupId: string) =>
|
||||
createMemoizedSelector(selectNodesSlice, ({ nodes }) => {
|
||||
return nodes
|
||||
.filter(isInvocationNode)
|
||||
.filter(isBatchNode)
|
||||
.filter((node) => node.data.inputs['batch_group_id']?.value === batchGroupId)
|
||||
.map(getBatchCollectionSize);
|
||||
});
|
||||
|
||||
const selectUngroupedBatchSizes = buildSelectGroupBatchSizes('None');
|
||||
const selectGroup1BatchSizes = buildSelectGroupBatchSizes('Group 1');
|
||||
const selectGroup2BatchSizes = buildSelectGroupBatchSizes('Group 2');
|
||||
const selectGroup3BatchSizes = buildSelectGroupBatchSizes('Group 3');
|
||||
const selectGroup4BatchSizes = buildSelectGroupBatchSizes('Group 4');
|
||||
const selectGroup5BatchSizes = buildSelectGroupBatchSizes('Group 5');
|
||||
|
||||
export const selectWorkflowsBatchSize = createSelector(
|
||||
selectUngroupedBatchSizes,
|
||||
selectGroup1BatchSizes,
|
||||
selectGroup2BatchSizes,
|
||||
selectGroup3BatchSizes,
|
||||
selectGroup4BatchSizes,
|
||||
selectGroup5BatchSizes,
|
||||
(
|
||||
ungroupedBatchSizes,
|
||||
group1BatchSizes,
|
||||
group2BatchSizes,
|
||||
group3BatchSizes,
|
||||
group4BatchSizes,
|
||||
group5BatchSizes
|
||||
): number | 'INVALID' | 'NO_BATCHES' => {
|
||||
// All batch nodes _must_ have a populated collection
|
||||
|
||||
const allBatchSizes = [
|
||||
...ungroupedBatchSizes,
|
||||
...group1BatchSizes,
|
||||
...group2BatchSizes,
|
||||
...group3BatchSizes,
|
||||
...group4BatchSizes,
|
||||
...group5BatchSizes,
|
||||
];
|
||||
|
||||
// There are no batch nodes
|
||||
if (allBatchSizes.length === 0) {
|
||||
return 'NO_BATCHES';
|
||||
}
|
||||
|
||||
// All batch nodes must have a populated collection
|
||||
if (allBatchSizes.some((size) => size === 0)) {
|
||||
return 'INVALID';
|
||||
}
|
||||
|
||||
for (const group of [group1BatchSizes, group2BatchSizes, group3BatchSizes, group4BatchSizes, group5BatchSizes]) {
|
||||
// Ignore groups with no batch nodes
|
||||
if (group.length === 0) {
|
||||
continue;
|
||||
}
|
||||
// Grouped batch nodes must have the same collection size
|
||||
if (group.some((size) => size !== group[0])) {
|
||||
return 'INVALID';
|
||||
}
|
||||
}
|
||||
|
||||
// Total batch size = product of all ungrouped batches and each grouped batch
|
||||
const totalBatchSize = [
|
||||
...ungroupedBatchSizes,
|
||||
// In case of no batch nodes in a group, fall back to 1 for the product calculation
|
||||
group1BatchSizes[0] ?? 1,
|
||||
group2BatchSizes[0] ?? 1,
|
||||
group3BatchSizes[0] ?? 1,
|
||||
group4BatchSizes[0] ?? 1,
|
||||
group5BatchSizes[0] ?? 1,
|
||||
].reduce((acc, size) => acc * size, 1);
|
||||
|
||||
return totalBatchSize;
|
||||
}
|
||||
);
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -189,6 +189,26 @@ def test_cannot_create_bad_batch_items_type(batch_graph):
|
||||
)
|
||||
|
||||
|
||||
def test_number_type_interop(batch_graph):
|
||||
# integers and floats can be mixed, should not throw an error
|
||||
Batch(
|
||||
graph=batch_graph,
|
||||
data=[
|
||||
[
|
||||
BatchDatum(node_path="1", field_name="prompt", items=[1, 1.5]),
|
||||
]
|
||||
],
|
||||
)
|
||||
Batch(
|
||||
graph=batch_graph,
|
||||
data=[
|
||||
[
|
||||
BatchDatum(node_path="1", field_name="prompt", items=[1.5, 1]),
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def test_cannot_create_bad_batch_unique_ids(batch_graph):
|
||||
with pytest.raises(ValidationError, match="Each batch data must have unique node_id and field_name"):
|
||||
Batch(
|
||||
|
||||
Reference in New Issue
Block a user