feat(ui): support zipped batch nodes

This commit is contained in:
psychedelicious
2025-01-14 11:37:06 +11:00
parent 2e4110a29a
commit edff16124f

View File

@@ -10,9 +10,10 @@ import {
isStringFieldCollectionInputInstance,
} from 'features/nodes/types/field';
import type { InvocationNodeEdge } from 'features/nodes/types/invocation';
import { isInvocationNode } from 'features/nodes/types/invocation';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { buildWorkflowWithValidation } from 'features/nodes/util/workflow/buildWorkflow';
import { groupBy } from 'lodash-es';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import type { Batch, BatchConfig } from 'services/api/types';
@@ -40,92 +41,103 @@ export const addEnqueueRequestedNodes = (startAppListening: AppStartListening) =
const data: Batch['data'] = [];
const addBatchDataCollectionItem = (edges: InvocationNodeEdge[], items?: ImageField[] | string[] | number[]) => {
const batchNodes = nodes.nodes.filter(isInvocationNode).filter(isBatchNode);
// Handle zipping batch nodes. First group the batch nodes by their link_id
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['link_id']?.value);
// Then, we will create a batch data collection item for each group
for (const [_linkId, batchNodes] of Object.entries(groupedBatchNodes)) {
const batchDataCollectionItem: NonNullable<Batch['data']>[number] = [];
for (const edge of edges) {
if (!edge.targetHandle) {
const addBatchDataCollectionItem = (
edges: InvocationNodeEdge[],
items?: ImageField[] | string[] | number[]
) => {
for (const edge of edges) {
if (!edge.targetHandle) {
break;
}
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items,
});
}
};
// Grab image batch nodes for special handling
const imageBatchNodes = batchNodes.filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
// Satisfy TS
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
batchDataCollectionItem.push({
node_path: edge.target,
field_name: edge.targetHandle,
items,
});
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
addBatchDataCollectionItem(edgesFromImageBatch, images.value);
}
// Grab string batch nodes for special handling
const stringBatchNodes = batchNodes.filter((node) => node.data.type === 'string_batch');
for (const node of stringBatchNodes) {
// Satisfy TS
const strings = node.data.inputs['strings'];
if (!isStringFieldCollectionInputInstance(strings)) {
log.warn({ nodeId: node.id }, 'String batch strings field is not a string collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, strings.value);
}
// Grab integer batch nodes for special handling
const integerBatchNodes = batchNodes.filter((node) => node.data.type === 'integer_batch');
for (const node of integerBatchNodes) {
// Satisfy TS
const integers = node.data.inputs['integers'];
if (!isIntegerFieldCollectionInputInstance(integers)) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is not an integer collection');
break;
}
if (!integers.value) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is empty');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, integers.value);
}
// Grab float batch nodes for special handling
const floatBatchNodes = batchNodes.filter((node) => node.data.type === 'float_batch');
for (const node of floatBatchNodes) {
// Satisfy TS
const floats = node.data.inputs['floats'];
if (!isFloatFieldCollectionInputInstance(floats)) {
log.warn({ nodeId: node.id }, 'Float batch floats field is not a float collection');
break;
}
if (!floats.value) {
log.warn({ nodeId: node.id }, 'Float batch floats field is empty');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, floats.value);
}
// Finally, if this batch data collection item has any items, add it to the data array
if (batchDataCollectionItem.length > 0) {
data.push(batchDataCollectionItem);
}
};
// Grab image batch nodes for special handling
const imageBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'image_batch');
for (const node of imageBatchNodes) {
// Satisfy TS
const images = node.data.inputs['images'];
if (!isImageFieldCollectionInputInstance(images)) {
log.warn({ nodeId: node.id }, 'Image batch images field is not an image collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromImageBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'image');
addBatchDataCollectionItem(edgesFromImageBatch, images.value);
}
// Grab string batch nodes for special handling
const stringBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'string_batch');
for (const node of stringBatchNodes) {
// Satisfy TS
const strings = node.data.inputs['strings'];
if (!isStringFieldCollectionInputInstance(strings)) {
log.warn({ nodeId: node.id }, 'String batch strings field is not a string collection');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, strings.value);
}
// Grab integer batch nodes for special handling
const integerBatchNodes = nodes.nodes
.filter(isInvocationNode)
.filter((node) => node.data.type === 'integer_batch');
for (const node of integerBatchNodes) {
// Satisfy TS
const integers = node.data.inputs['integers'];
if (!isIntegerFieldCollectionInputInstance(integers)) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is not an integer collection');
break;
}
if (!integers.value) {
log.warn({ nodeId: node.id }, 'Integer batch integers field is empty');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, integers.value);
}
// Grab float batch nodes for special handling
const floatBatchNodes = nodes.nodes.filter(isInvocationNode).filter((node) => node.data.type === 'float_batch');
for (const node of floatBatchNodes) {
// Satisfy TS
const floats = node.data.inputs['floats'];
if (!isFloatFieldCollectionInputInstance(floats)) {
log.warn({ nodeId: node.id }, 'Float batch floats field is not a float collection');
break;
}
if (!floats.value) {
log.warn({ nodeId: node.id }, 'Float batch floats field is empty');
break;
}
// Find outgoing edges from the batch node, we will remove these from the graph and create batch data collection items from them instead
const edgesFromStringBatch = nodes.edges.filter((e) => e.source === node.id && e.sourceHandle === 'value');
addBatchDataCollectionItem(edgesFromStringBatch, floats.value);
}
const batchConfig: BatchConfig = {