refactor(ui): dedicated enqueue funcs for each tab

This commit is contained in:
psychedelicious
2025-06-30 15:33:55 +10:00
parent b113c57fc4
commit fb883d63aa
8 changed files with 490 additions and 340 deletions

View File

@@ -8,7 +8,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
@@ -20,7 +19,6 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi
import type { AppDispatch, RootState } from 'app/store/store';
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
export const listenerMiddleware = createListenerMiddleware();
@@ -43,8 +41,6 @@ addImageUploadedFulfilledListener(startAppListening);
addDeleteBoardAndImagesFulfilledListener(startAppListening);
// User Invoked
addEnqueueRequestedLinear(startAppListening);
addEnqueueRequestedUpscale(startAppListening);
addAnyEnqueuedListener(startAppListening);
addBatchEnqueuedListener(startAppListening);

View File

@@ -1,154 +0,0 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import {
canvasSessionIdCreated,
generateSessionIdCreated,
selectCanvasSessionId,
selectGenerateSessionId,
} from 'features/controlLayers/store/canvasStagingAreaSlice';
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { selectActiveTab } from 'features/ui/store/uiSelectors';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: enqueueRequestedCanvas,
effect: async (action, { getState, dispatch }) => {
log.debug('Enqueue requested');
const tab = selectActiveTab(getState());
let sessionId = null;
if (tab === 'generate') {
sessionId = selectGenerateSessionId(getState());
if (!sessionId) {
dispatch(generateSessionIdCreated());
sessionId = selectGenerateSessionId(getState());
}
} else if (tab === 'canvas') {
sessionId = selectCanvasSessionId(getState());
if (!sessionId) {
dispatch(canvasSessionIdCreated());
sessionId = selectCanvasSessionId(getState());
}
} else {
log.warn(`Enqueue requested in unsupported tab ${tab}`);
return;
}
const state = getState();
const destination = sessionId;
assert(destination !== null);
const { prepend } = action.payload;
const manager = $canvasManager.get();
// assert(manager, 'No canvas manager');
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, manager);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, manager);
case `sd-3`:
return await buildSD3Graph(state, manager);
case `flux`:
return await buildFLUXGraph(state, manager);
case 'cogview4':
return await buildCogView4Graph(state, manager);
case 'imagen3':
return await buildImagen3Graph(state, manager);
case 'imagen4':
return await buildImagen4Graph(state, manager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, manager);
case 'flux-kontext':
return await buildFluxKontextGraph(state, manager);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: tab,
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, enqueueMutationFixedCacheKeyOptions)
);
try {
await req.unwrap();
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}
},
});
};

View File

@@ -1,44 +0,0 @@
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
import { parseify } from 'common/util/serialize';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
const log = logger('generation');
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
startAppListening({
actionCreator: enqueueRequestedUpscaling,
effect: async (action, { getState, dispatch }) => {
const state = getState();
const { prepend } = action.payload;
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
try {
await req.unwrap();
log.debug(parseify({ batchConfig }), 'Enqueued batch');
} catch (error) {
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
} finally {
req.reset();
}
},
});
};

View File

@@ -0,0 +1,143 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { canvasSessionIdCreated, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedCanvas());
let destination = selectCanvasSessionId(getState());
if (!destination) {
dispatch(canvasSessionIdCreated());
destination = selectCanvasSessionId(getState());
}
assert(destination !== null);
const state = getState();
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, canvasManager);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, canvasManager);
case `sd-3`:
return await buildSD3Graph(state, canvasManager);
case `flux`:
return await buildFLUXGraph(state, canvasManager);
case 'cogview4':
return await buildCogView4Graph(state, canvasManager);
case 'imagen3':
return await buildImagen3Graph(state, canvasManager);
case 'imagen4':
return await buildImagen4Graph(state, canvasManager);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, canvasManager);
case 'flux-kontext':
return await buildFluxKontextGraph(state, canvasManager);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
};
export const useEnqueueCanvas = () => {
const store = useAppStore();
const canvasManager = useCanvasManagerSafe();
const enqueue = useCallback(
(prepend: boolean) => {
if (!canvasManager) {
log.error('Canvas manager is not available');
return;
}
return enqueueCanvas(store, canvasManager, prepend);
},
[canvasManager, store]
);
return enqueue;
};

View File

@@ -0,0 +1,137 @@
import type { AlertStatus } from '@invoke-ai/ui-library';
import { createAction } from '@reduxjs/toolkit';
import { logger } from 'app/logging/logger';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
import { withResult, withResultAsync } from 'common/util/result';
import { generateSessionIdCreated, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
import { toast } from 'features/toast/toast';
import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
import { assert, AssertionError } from 'tsafe';
const log = logger('generation');
export const enqueueRequestedGenerate = createAction('app/enqueueRequestedGenerate');
const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedGenerate());
let destination = selectGenerateSessionId(getState());
if (!destination) {
dispatch(generateSessionIdCreated());
destination = selectGenerateSessionId(getState());
}
assert(destination !== null);
const state = getState();
const model = state.params.model;
assert(model, 'No model found in state');
const base = model.base;
const buildGraphResult = await withResultAsync(async () => {
switch (base) {
case 'sdxl':
return await buildSDXLGraph(state, null);
case 'sd-1':
case `sd-2`:
return await buildSD1Graph(state, null);
case `sd-3`:
return await buildSD3Graph(state, null);
case `flux`:
return await buildFLUXGraph(state, null);
case 'cogview4':
return await buildCogView4Graph(state, null);
case 'imagen3':
return await buildImagen3Graph(state, null);
case 'imagen4':
return await buildImagen4Graph(state, null);
case 'chatgpt-4o':
return await buildChatGPT4oGraph(state, null);
case 'flux-kontext':
return await buildFluxKontextGraph(state, null);
default:
assert(false, `No graph builders for base ${base}`);
}
});
if (buildGraphResult.isErr()) {
let title = 'Failed to build graph';
let status: AlertStatus = 'error';
let description: string | null = null;
if (buildGraphResult.error instanceof AssertionError) {
description = extractMessageFromAssertionError(buildGraphResult.error);
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
title = 'Unsupported generation mode';
description = buildGraphResult.error.message;
status = 'warning';
}
const error = serializeError(buildGraphResult.error);
log.error({ error }, 'Failed to build graph');
toast({
status,
title,
description,
});
return;
}
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
const prepareBatchResult = withResult(() =>
prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'canvas',
destination,
})
);
if (prepareBatchResult.isErr()) {
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
return;
}
const batchConfig = prepareBatchResult.value;
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
...enqueueMutationFixedCacheKeyOptions,
track: false,
})
);
const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
};
export const useEnqueueGenerate = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueGenerate(store, prepend);
},
[store]
);
return enqueue;
};

View File

@@ -0,0 +1,47 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppStore } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
import { useCallback } from 'react';
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling');
const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedUpscaling());
const state = getState();
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
const batchConfig = prepareLinearUIBatch({
state,
g,
prepend,
seedFieldIdentifier,
positivePromptFieldIdentifier,
origin: 'upscaling',
destination: 'gallery',
});
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
};
export const useEnqueueUpscaling = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean) => {
return enqueueUpscaling(store, prepend);
},
[store]
);
return enqueue;
};

View File

@@ -1,4 +1,5 @@
import { createAction } from '@reduxjs/toolkit';
import type { AppDispatch, AppStore, RootState } from 'app/store/store';
import { useAppStore } from 'app/store/storeHooks';
import { groupBy } from 'es-toolkit/compat';
import {
@@ -8,6 +9,7 @@ import {
} from 'features/nodes/components/sidePanel/workflow/publish';
import { $templates } from 'features/nodes/store/nodesSlice';
import { selectNodeData, selectNodesSlice } from 'features/nodes/store/selectors';
import type { Templates } from 'features/nodes/store/types';
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
@@ -19,138 +21,160 @@ import { assert } from 'tsafe';
const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');
export const useEnqueueWorkflows = () => {
const { getState, dispatch } = useAppStore();
const enqueue = useCallback(
async (prepend: boolean, isApiValidationRun: boolean) => {
dispatch(enqueueRequestedWorkflows());
const state = getState();
const nodesState = selectNodesSlice(state);
const templates = $templates.get();
const graph = buildNodesGraph(state, templates);
const builtWorkflow = buildWorkflowWithValidation(nodesState);
const getBatchDataForWorkflowGeneration = async (state: RootState, dispatch: AppDispatch): Promise<Batch['data']> => {
const nodesState = selectNodesSlice(state);
const data: Batch['data'] = [];
if (builtWorkflow) {
// embedded workflows don't have an id
delete builtWorkflow.id;
}
const invocationNodes = nodesState.nodes.filter(isInvocationNode);
const batchNodes = invocationNodes.filter(isBatchNode);
const data: Batch['data'] = [];
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
const invocationNodes = nodesState.nodes.filter(isInvocationNode);
const batchNodes = invocationNodes.filter(isBatchNode);
// Then, we will create a batch data collection item for each group
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
// Then, we will create a batch data collection item for each group
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const node of batchNodes) {
const value = await resolveBatchValue({ nodesState, node, dispatch });
const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value';
const edgesFromBatch = nodesState.edges.filter(
(e) => e.source === node.id && e.sourceHandle === sourceHandle
);
if (batchGroupId !== 'None') {
// If this batch node has a batch_group_id, we will zip the data collection items
for (const edge of edgesFromBatch) {
if (!edge.targetHandle) {
break;
}
zippedBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: value,
});
}
} else {
// Otherwise add the data collection items to root of the batch so they are not zipped
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromBatch) {
if (!edge.targetHandle) {
break;
}
productBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: value,
});
}
if (productBatchDataCollectionItems.length > 0) {
data.push(productBatchDataCollectionItems);
}
for (const node of batchNodes) {
const value = await resolveBatchValue({ nodesState, node, dispatch });
const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value';
const edgesFromBatch = nodesState.edges.filter((e) => e.source === node.id && e.sourceHandle === sourceHandle);
if (batchGroupId !== 'None') {
// If this batch node has a batch_group_id, we will zip the data collection items
for (const edge of edgesFromBatch) {
if (!edge.targetHandle) {
break;
}
zippedBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: value,
});
}
// Finally, if this batch data collection item has any items, add it to the data array
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
data.push(zippedBatchDataCollectionItems);
} else {
// Otherwise add the data collection items to root of the batch so they are not zipped
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
for (const edge of edgesFromBatch) {
if (!edge.targetHandle) {
break;
}
productBatchDataCollectionItems.push({
node_path: edge.target,
field_name: edge.targetHandle,
items: value,
});
}
if (productBatchDataCollectionItems.length > 0) {
data.push(productBatchDataCollectionItems);
}
}
}
const batchConfig: EnqueueBatchArg = {
batch: {
graph,
workflow: builtWorkflow,
runs: state.params.iterations,
origin: 'workflows',
destination: 'gallery',
data,
},
prepend,
};
// Finally, if this batch data collection item has any items, add it to the data array
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
data.push(zippedBatchDataCollectionItems);
}
}
if (isApiValidationRun) {
// Derive the input fields from the builder's selected node field elements
const fieldIdentifiers = selectFieldIdentifiersWithInvocationTypes(state);
const inputs = getPublishInputs(fieldIdentifiers, templates);
const api_input_fields = inputs.publishable.map(({ nodeId, fieldName, label }) => {
return {
kind: 'input',
node_id: nodeId,
field_name: fieldName,
user_label: label,
} satisfies S['FieldIdentifier'];
});
return data;
};
// Derive the output fields from the builder's selected output node
const outputNodeId = $outputNodeId.get();
assert(outputNodeId !== null, 'Output node not selected');
const outputNodeType = selectNodeData(selectNodesSlice(state), outputNodeId).type;
const outputNodeTemplate = templates[outputNodeType];
assert(outputNodeTemplate, `Template for node type ${outputNodeType} not found`);
const outputFieldNames = Object.keys(outputNodeTemplate.outputs);
const api_output_fields = outputFieldNames.map((fieldName) => {
return {
kind: 'output',
node_id: outputNodeId,
field_name: fieldName,
user_label: null,
} satisfies S['FieldIdentifier'];
});
const getValidationRunData = (state: RootState, templates: Templates): S['ValidationRunData'] => {
const nodesState = selectNodesSlice(state);
assert(nodesState.id, 'Workflow without ID cannot be used for API validation run');
// Derive the input fields from the builder's selected node field elements
const fieldIdentifiers = selectFieldIdentifiersWithInvocationTypes(state);
const inputs = getPublishInputs(fieldIdentifiers, templates);
const api_input_fields = inputs.publishable.map(({ nodeId, fieldName, label }) => {
return {
kind: 'input',
node_id: nodeId,
field_name: fieldName,
user_label: label,
} satisfies S['FieldIdentifier'];
});
batchConfig.validation_run_data = {
workflow_id: nodesState.id,
input_fields: api_input_fields,
output_fields: api_output_fields,
};
// Derive the output fields from the builder's selected output node
const outputNodeId = $outputNodeId.get();
assert(outputNodeId !== null, 'Output node not selected');
const outputNodeType = selectNodeData(selectNodesSlice(state), outputNodeId).type;
const outputNodeTemplate = templates[outputNodeType];
assert(outputNodeTemplate, `Template for node type ${outputNodeType} not found`);
const outputFieldNames = Object.keys(outputNodeTemplate.outputs);
const api_output_fields = outputFieldNames.map((fieldName) => {
return {
kind: 'output',
node_id: outputNodeId,
field_name: fieldName,
user_label: null,
} satisfies S['FieldIdentifier'];
});
// If the batch is an API validation run, we only want to run it once
batchConfig.batch.runs = 1;
}
assert(nodesState.id, 'Workflow without ID cannot be used for API validation run');
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
return {
workflow_id: nodesState.id,
input_fields: api_input_fields,
output_fields: api_output_fields,
};
};
const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
const enqueueWorkflows = async (
store: AppStore,
templates: Templates,
prepend: boolean,
isApiValidationRun: boolean
) => {
const { dispatch, getState } = store;
dispatch(enqueueRequestedWorkflows());
const state = getState();
const nodesState = selectNodesSlice(state);
const graph = buildNodesGraph(state, templates);
const workflow = buildWorkflowWithValidation(nodesState);
if (workflow) {
// embedded workflows don't have an id
delete workflow.id;
}
const runs = state.params.iterations;
const data = await getBatchDataForWorkflowGeneration(state, dispatch);
const batchConfig: EnqueueBatchArg = {
batch: {
graph,
workflow,
runs,
origin: 'workflows',
destination: 'gallery',
data,
},
[dispatch, getState]
prepend,
};
if (isApiValidationRun) {
batchConfig.validation_run_data = getValidationRunData(state, templates);
// If the batch is an API validation run, we only want to run it once
batchConfig.batch.runs = 1;
}
const req = dispatch(
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
);
const enqueueResult = await req.unwrap();
return { batchConfig, enqueueResult };
};
export const useEnqueueWorkflows = () => {
const store = useAppStore();
const enqueue = useCallback(
(prepend: boolean, isApiValidationRun: boolean) => {
return enqueueWorkflows(store, $templates.get(), prepend, isApiValidationRun);
},
[store]
);
return enqueue;

View File

@@ -1,10 +1,7 @@
import { useStore } from '@nanostores/react';
import { logger } from 'app/logging/logger';
import { enqueueRequestedCanvas } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
import { enqueueRequestedUpscaling } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedUpscale';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useAppSelector } from 'app/store/storeHooks';
import { withResultAsync } from 'common/util/result';
import { parseify } from 'common/util/serialize';
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
@@ -15,15 +12,21 @@ import { useCallback } from 'react';
import { serializeError } from 'serialize-error';
import { enqueueMutationFixedCacheKeyOptions, useEnqueueBatchMutation } from 'services/api/endpoints/queue';
import { useEnqueueCanvas } from './useEnqueueCanvas';
import { useEnqueueGenerate } from './useEnqueueGenerate';
import { useEnqueueUpscaling } from './useEnqueueUpscaling';
const log = logger('generation');
export const useInvoke = () => {
const dispatch = useAppDispatch();
const ctx = useAutoLayoutContextSafe();
const tabName = useAppSelector(selectActiveTab);
const isReady = useStore($isReadyToEnqueue);
const isLocked = useIsWorkflowEditorLocked();
const enqueueWorkflows = useEnqueueWorkflows();
const enqueueCanvas = useEnqueueCanvas();
const enqueueGenerate = useEnqueueGenerate();
const enqueueUpscaling = useEnqueueUpscaling();
const [_, { isLoading }] = useEnqueueBatchMutation(enqueueMutationFixedCacheKeyOptions);
@@ -33,28 +36,26 @@ export const useInvoke = () => {
return;
}
if (tabName === 'workflows') {
const result = await withResultAsync(() => enqueueWorkflows(prepend, isApiValidationRun));
if (result.isErr()) {
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
} else {
log.debug(parseify(result.value), 'Enqueued batch');
const result = await withResultAsync(async () => {
switch (tabName) {
case 'workflows':
return await enqueueWorkflows(prepend, isApiValidationRun);
case 'canvas':
return await enqueueCanvas(prepend);
case 'generate':
return await enqueueGenerate(prepend);
case 'upscaling':
return await enqueueUpscaling(prepend);
default:
throw new Error(`No enqueue handler for tab: ${tabName}`);
}
}
});
if (tabName === 'upscaling') {
dispatch(enqueueRequestedUpscaling({ prepend }));
return;
if (result.isErr()) {
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
}
if (tabName === 'canvas' || tabName === 'generate') {
dispatch(enqueueRequestedCanvas({ prepend }));
return;
}
// Else we are not on a generation tab and should not queue
},
[dispatch, enqueueWorkflows, isReady, tabName]
[enqueueCanvas, enqueueGenerate, enqueueUpscaling, enqueueWorkflows, isReady, tabName]
);
const enqueueBack = useCallback(() => {