mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-13 21:25:04 -05:00
refactor(ui): dedicated enqueue funcs for each tab
This commit is contained in:
@@ -8,7 +8,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar
|
||||
import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted';
|
||||
import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected';
|
||||
import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload';
|
||||
import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema';
|
||||
import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard';
|
||||
import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard';
|
||||
@@ -20,7 +19,6 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi
|
||||
import type { AppDispatch, RootState } from 'app/store/store';
|
||||
|
||||
import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener';
|
||||
import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale';
|
||||
|
||||
export const listenerMiddleware = createListenerMiddleware();
|
||||
|
||||
@@ -43,8 +41,6 @@ addImageUploadedFulfilledListener(startAppListening);
|
||||
addDeleteBoardAndImagesFulfilledListener(startAppListening);
|
||||
|
||||
// User Invoked
|
||||
addEnqueueRequestedLinear(startAppListening);
|
||||
addEnqueueRequestedUpscale(startAppListening);
|
||||
addAnyEnqueuedListener(startAppListening);
|
||||
addBatchEnqueuedListener(startAppListening);
|
||||
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
import type { AlertStatus } from '@invoke-ai/ui-library';
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import {
|
||||
canvasSessionIdCreated,
|
||||
generateSessionIdCreated,
|
||||
selectCanvasSessionId,
|
||||
selectGenerateSessionId,
|
||||
} from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { $canvasManager } from 'features/controlLayers/store/ephemeral';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
|
||||
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
|
||||
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
|
||||
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { selectActiveTab } from 'features/ui/store/uiSelectors';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import { assert, AssertionError } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas');
|
||||
|
||||
export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: enqueueRequestedCanvas,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
log.debug('Enqueue requested');
|
||||
|
||||
const tab = selectActiveTab(getState());
|
||||
let sessionId = null;
|
||||
if (tab === 'generate') {
|
||||
sessionId = selectGenerateSessionId(getState());
|
||||
if (!sessionId) {
|
||||
dispatch(generateSessionIdCreated());
|
||||
sessionId = selectGenerateSessionId(getState());
|
||||
}
|
||||
} else if (tab === 'canvas') {
|
||||
sessionId = selectCanvasSessionId(getState());
|
||||
if (!sessionId) {
|
||||
dispatch(canvasSessionIdCreated());
|
||||
sessionId = selectCanvasSessionId(getState());
|
||||
}
|
||||
} else {
|
||||
log.warn(`Enqueue requested in unsupported tab ${tab}`);
|
||||
return;
|
||||
}
|
||||
|
||||
const state = getState();
|
||||
const destination = sessionId;
|
||||
assert(destination !== null);
|
||||
|
||||
const { prepend } = action.payload;
|
||||
|
||||
const manager = $canvasManager.get();
|
||||
// assert(manager, 'No canvas manager');
|
||||
|
||||
const model = state.params.model;
|
||||
assert(model, 'No model found in state');
|
||||
const base = model.base;
|
||||
|
||||
const buildGraphResult = await withResultAsync(async () => {
|
||||
switch (base) {
|
||||
case 'sdxl':
|
||||
return await buildSDXLGraph(state, manager);
|
||||
case 'sd-1':
|
||||
case `sd-2`:
|
||||
return await buildSD1Graph(state, manager);
|
||||
case `sd-3`:
|
||||
return await buildSD3Graph(state, manager);
|
||||
case `flux`:
|
||||
return await buildFLUXGraph(state, manager);
|
||||
case 'cogview4':
|
||||
return await buildCogView4Graph(state, manager);
|
||||
case 'imagen3':
|
||||
return await buildImagen3Graph(state, manager);
|
||||
case 'imagen4':
|
||||
return await buildImagen4Graph(state, manager);
|
||||
case 'chatgpt-4o':
|
||||
return await buildChatGPT4oGraph(state, manager);
|
||||
case 'flux-kontext':
|
||||
return await buildFluxKontextGraph(state, manager);
|
||||
default:
|
||||
assert(false, `No graph builders for base ${base}`);
|
||||
}
|
||||
});
|
||||
|
||||
if (buildGraphResult.isErr()) {
|
||||
let title = 'Failed to build graph';
|
||||
let status: AlertStatus = 'error';
|
||||
let description: string | null = null;
|
||||
if (buildGraphResult.error instanceof AssertionError) {
|
||||
description = extractMessageFromAssertionError(buildGraphResult.error);
|
||||
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
|
||||
title = 'Unsupported generation mode';
|
||||
description = buildGraphResult.error.message;
|
||||
status = 'warning';
|
||||
}
|
||||
const error = serializeError(buildGraphResult.error);
|
||||
log.error({ error }, 'Failed to build graph');
|
||||
toast({
|
||||
status,
|
||||
title,
|
||||
description,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
prepend,
|
||||
seedFieldIdentifier,
|
||||
positivePromptFieldIdentifier,
|
||||
origin: tab,
|
||||
destination,
|
||||
})
|
||||
);
|
||||
|
||||
if (prepareBatchResult.isErr()) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
return;
|
||||
}
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, enqueueMutationFixedCacheKeyOptions)
|
||||
);
|
||||
|
||||
try {
|
||||
await req.unwrap();
|
||||
log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -1,44 +0,0 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStartListening } from 'app/store/middleware/listenerMiddleware';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling');
|
||||
|
||||
export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => {
|
||||
startAppListening({
|
||||
actionCreator: enqueueRequestedUpscaling,
|
||||
effect: async (action, { getState, dispatch }) => {
|
||||
const state = getState();
|
||||
const { prepend } = action.payload;
|
||||
|
||||
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
|
||||
|
||||
const batchConfig = prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
prepend,
|
||||
seedFieldIdentifier,
|
||||
positivePromptFieldIdentifier,
|
||||
origin: 'upscaling',
|
||||
destination: 'gallery',
|
||||
});
|
||||
|
||||
const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions));
|
||||
try {
|
||||
await req.unwrap();
|
||||
log.debug(parseify({ batchConfig }), 'Enqueued batch');
|
||||
} catch (error) {
|
||||
log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch');
|
||||
} finally {
|
||||
req.reset();
|
||||
}
|
||||
},
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,143 @@
|
||||
import type { AlertStatus } from '@invoke-ai/ui-library';
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate';
|
||||
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
|
||||
import { canvasSessionIdCreated, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
|
||||
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
|
||||
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
|
||||
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import { assert, AssertionError } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas');
|
||||
|
||||
const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
dispatch(enqueueRequestedCanvas());
|
||||
|
||||
let destination = selectCanvasSessionId(getState());
|
||||
if (!destination) {
|
||||
dispatch(canvasSessionIdCreated());
|
||||
destination = selectCanvasSessionId(getState());
|
||||
}
|
||||
assert(destination !== null);
|
||||
|
||||
const state = getState();
|
||||
|
||||
const model = state.params.model;
|
||||
assert(model, 'No model found in state');
|
||||
const base = model.base;
|
||||
|
||||
const buildGraphResult = await withResultAsync(async () => {
|
||||
switch (base) {
|
||||
case 'sdxl':
|
||||
return await buildSDXLGraph(state, canvasManager);
|
||||
case 'sd-1':
|
||||
case `sd-2`:
|
||||
return await buildSD1Graph(state, canvasManager);
|
||||
case `sd-3`:
|
||||
return await buildSD3Graph(state, canvasManager);
|
||||
case `flux`:
|
||||
return await buildFLUXGraph(state, canvasManager);
|
||||
case 'cogview4':
|
||||
return await buildCogView4Graph(state, canvasManager);
|
||||
case 'imagen3':
|
||||
return await buildImagen3Graph(state, canvasManager);
|
||||
case 'imagen4':
|
||||
return await buildImagen4Graph(state, canvasManager);
|
||||
case 'chatgpt-4o':
|
||||
return await buildChatGPT4oGraph(state, canvasManager);
|
||||
case 'flux-kontext':
|
||||
return await buildFluxKontextGraph(state, canvasManager);
|
||||
default:
|
||||
assert(false, `No graph builders for base ${base}`);
|
||||
}
|
||||
});
|
||||
|
||||
if (buildGraphResult.isErr()) {
|
||||
let title = 'Failed to build graph';
|
||||
let status: AlertStatus = 'error';
|
||||
let description: string | null = null;
|
||||
if (buildGraphResult.error instanceof AssertionError) {
|
||||
description = extractMessageFromAssertionError(buildGraphResult.error);
|
||||
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
|
||||
title = 'Unsupported generation mode';
|
||||
description = buildGraphResult.error.message;
|
||||
status = 'warning';
|
||||
}
|
||||
const error = serializeError(buildGraphResult.error);
|
||||
log.error({ error }, 'Failed to build graph');
|
||||
toast({
|
||||
status,
|
||||
title,
|
||||
description,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
prepend,
|
||||
seedFieldIdentifier,
|
||||
positivePromptFieldIdentifier,
|
||||
origin: 'canvas',
|
||||
destination,
|
||||
})
|
||||
);
|
||||
|
||||
if (prepareBatchResult.isErr()) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
return;
|
||||
}
|
||||
|
||||
const batchConfig = prepareBatchResult.value;
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
...enqueueMutationFixedCacheKeyOptions,
|
||||
track: false,
|
||||
})
|
||||
);
|
||||
|
||||
const enqueueResult = await req.unwrap();
|
||||
|
||||
return { batchConfig, enqueueResult };
|
||||
};
|
||||
|
||||
export const useEnqueueCanvas = () => {
|
||||
const store = useAppStore();
|
||||
const canvasManager = useCanvasManagerSafe();
|
||||
const enqueue = useCallback(
|
||||
(prepend: boolean) => {
|
||||
if (!canvasManager) {
|
||||
log.error('Canvas manager is not available');
|
||||
return;
|
||||
}
|
||||
return enqueueCanvas(store, canvasManager, prepend);
|
||||
},
|
||||
[canvasManager, store]
|
||||
);
|
||||
return enqueue;
|
||||
};
|
||||
@@ -0,0 +1,137 @@
|
||||
import type { AlertStatus } from '@invoke-ai/ui-library';
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError';
|
||||
import { withResult, withResultAsync } from 'common/util/result';
|
||||
import { generateSessionIdCreated, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph';
|
||||
import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph';
|
||||
import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph';
|
||||
import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph';
|
||||
import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph';
|
||||
import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph';
|
||||
import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph';
|
||||
import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph';
|
||||
import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph';
|
||||
import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types';
|
||||
import { toast } from 'features/toast/toast';
|
||||
import { useCallback } from 'react';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
import { assert, AssertionError } from 'tsafe';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const enqueueRequestedGenerate = createAction('app/enqueueRequestedGenerate');
|
||||
|
||||
const enqueueGenerate = async (store: AppStore, prepend: boolean) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
dispatch(enqueueRequestedGenerate());
|
||||
|
||||
let destination = selectGenerateSessionId(getState());
|
||||
if (!destination) {
|
||||
dispatch(generateSessionIdCreated());
|
||||
destination = selectGenerateSessionId(getState());
|
||||
}
|
||||
assert(destination !== null);
|
||||
|
||||
const state = getState();
|
||||
|
||||
const model = state.params.model;
|
||||
assert(model, 'No model found in state');
|
||||
const base = model.base;
|
||||
|
||||
const buildGraphResult = await withResultAsync(async () => {
|
||||
switch (base) {
|
||||
case 'sdxl':
|
||||
return await buildSDXLGraph(state, null);
|
||||
case 'sd-1':
|
||||
case `sd-2`:
|
||||
return await buildSD1Graph(state, null);
|
||||
case `sd-3`:
|
||||
return await buildSD3Graph(state, null);
|
||||
case `flux`:
|
||||
return await buildFLUXGraph(state, null);
|
||||
case 'cogview4':
|
||||
return await buildCogView4Graph(state, null);
|
||||
case 'imagen3':
|
||||
return await buildImagen3Graph(state, null);
|
||||
case 'imagen4':
|
||||
return await buildImagen4Graph(state, null);
|
||||
case 'chatgpt-4o':
|
||||
return await buildChatGPT4oGraph(state, null);
|
||||
case 'flux-kontext':
|
||||
return await buildFluxKontextGraph(state, null);
|
||||
default:
|
||||
assert(false, `No graph builders for base ${base}`);
|
||||
}
|
||||
});
|
||||
|
||||
if (buildGraphResult.isErr()) {
|
||||
let title = 'Failed to build graph';
|
||||
let status: AlertStatus = 'error';
|
||||
let description: string | null = null;
|
||||
if (buildGraphResult.error instanceof AssertionError) {
|
||||
description = extractMessageFromAssertionError(buildGraphResult.error);
|
||||
} else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) {
|
||||
title = 'Unsupported generation mode';
|
||||
description = buildGraphResult.error.message;
|
||||
status = 'warning';
|
||||
}
|
||||
const error = serializeError(buildGraphResult.error);
|
||||
log.error({ error }, 'Failed to build graph');
|
||||
toast({
|
||||
status,
|
||||
title,
|
||||
description,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value;
|
||||
|
||||
const prepareBatchResult = withResult(() =>
|
||||
prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
prepend,
|
||||
seedFieldIdentifier,
|
||||
positivePromptFieldIdentifier,
|
||||
origin: 'canvas',
|
||||
destination,
|
||||
})
|
||||
);
|
||||
|
||||
if (prepareBatchResult.isErr()) {
|
||||
log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch');
|
||||
return;
|
||||
}
|
||||
|
||||
const batchConfig = prepareBatchResult.value;
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, {
|
||||
...enqueueMutationFixedCacheKeyOptions,
|
||||
track: false,
|
||||
})
|
||||
);
|
||||
|
||||
const enqueueResult = await req.unwrap();
|
||||
|
||||
return { batchConfig, enqueueResult };
|
||||
};
|
||||
|
||||
export const useEnqueueGenerate = () => {
|
||||
const store = useAppStore();
|
||||
const enqueue = useCallback(
|
||||
(prepend: boolean) => {
|
||||
return enqueueGenerate(store, prepend);
|
||||
},
|
||||
[store]
|
||||
);
|
||||
return enqueue;
|
||||
};
|
||||
@@ -0,0 +1,47 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppStore } from 'app/store/store';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig';
|
||||
import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph';
|
||||
import { useCallback } from 'react';
|
||||
import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue';
|
||||
|
||||
const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling');
|
||||
|
||||
const enqueueUpscaling = async (store: AppStore, prepend: boolean) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
dispatch(enqueueRequestedUpscaling());
|
||||
|
||||
const state = getState();
|
||||
|
||||
const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state);
|
||||
|
||||
const batchConfig = prepareLinearUIBatch({
|
||||
state,
|
||||
g,
|
||||
prepend,
|
||||
seedFieldIdentifier,
|
||||
positivePromptFieldIdentifier,
|
||||
origin: 'upscaling',
|
||||
destination: 'gallery',
|
||||
});
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
|
||||
);
|
||||
const enqueueResult = await req.unwrap();
|
||||
|
||||
return { batchConfig, enqueueResult };
|
||||
};
|
||||
|
||||
export const useEnqueueUpscaling = () => {
|
||||
const store = useAppStore();
|
||||
const enqueue = useCallback(
|
||||
(prepend: boolean) => {
|
||||
return enqueueUpscaling(store, prepend);
|
||||
},
|
||||
[store]
|
||||
);
|
||||
return enqueue;
|
||||
};
|
||||
@@ -1,4 +1,5 @@
|
||||
import { createAction } from '@reduxjs/toolkit';
|
||||
import type { AppDispatch, AppStore, RootState } from 'app/store/store';
|
||||
import { useAppStore } from 'app/store/storeHooks';
|
||||
import { groupBy } from 'es-toolkit/compat';
|
||||
import {
|
||||
@@ -8,6 +9,7 @@ import {
|
||||
} from 'features/nodes/components/sidePanel/workflow/publish';
|
||||
import { $templates } from 'features/nodes/store/nodesSlice';
|
||||
import { selectNodeData, selectNodesSlice } from 'features/nodes/store/selectors';
|
||||
import type { Templates } from 'features/nodes/store/types';
|
||||
import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation';
|
||||
import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph';
|
||||
import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue';
|
||||
@@ -19,138 +21,160 @@ import { assert } from 'tsafe';
|
||||
|
||||
const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows');
|
||||
|
||||
export const useEnqueueWorkflows = () => {
|
||||
const { getState, dispatch } = useAppStore();
|
||||
const enqueue = useCallback(
|
||||
async (prepend: boolean, isApiValidationRun: boolean) => {
|
||||
dispatch(enqueueRequestedWorkflows());
|
||||
const state = getState();
|
||||
const nodesState = selectNodesSlice(state);
|
||||
const templates = $templates.get();
|
||||
const graph = buildNodesGraph(state, templates);
|
||||
const builtWorkflow = buildWorkflowWithValidation(nodesState);
|
||||
const getBatchDataForWorkflowGeneration = async (state: RootState, dispatch: AppDispatch): Promise<Batch['data']> => {
|
||||
const nodesState = selectNodesSlice(state);
|
||||
const data: Batch['data'] = [];
|
||||
|
||||
if (builtWorkflow) {
|
||||
// embedded workflows don't have an id
|
||||
delete builtWorkflow.id;
|
||||
}
|
||||
const invocationNodes = nodesState.nodes.filter(isInvocationNode);
|
||||
const batchNodes = invocationNodes.filter(isBatchNode);
|
||||
|
||||
const data: Batch['data'] = [];
|
||||
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
|
||||
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
|
||||
|
||||
const invocationNodes = nodesState.nodes.filter(isInvocationNode);
|
||||
const batchNodes = invocationNodes.filter(isBatchNode);
|
||||
// Then, we will create a batch data collection item for each group
|
||||
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
|
||||
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
|
||||
// Handle zipping batch nodes. First group the batch nodes by their batch_group_id
|
||||
const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value);
|
||||
|
||||
// Then, we will create a batch data collection item for each group
|
||||
for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) {
|
||||
const zippedBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
|
||||
for (const node of batchNodes) {
|
||||
const value = await resolveBatchValue({ nodesState, node, dispatch });
|
||||
const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value';
|
||||
const edgesFromBatch = nodesState.edges.filter(
|
||||
(e) => e.source === node.id && e.sourceHandle === sourceHandle
|
||||
);
|
||||
if (batchGroupId !== 'None') {
|
||||
// If this batch node has a batch_group_id, we will zip the data collection items
|
||||
for (const edge of edgesFromBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
zippedBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: value,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// Otherwise add the data collection items to root of the batch so they are not zipped
|
||||
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edgesFromBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
productBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: value,
|
||||
});
|
||||
}
|
||||
if (productBatchDataCollectionItems.length > 0) {
|
||||
data.push(productBatchDataCollectionItems);
|
||||
}
|
||||
for (const node of batchNodes) {
|
||||
const value = await resolveBatchValue({ nodesState, node, dispatch });
|
||||
const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value';
|
||||
const edgesFromBatch = nodesState.edges.filter((e) => e.source === node.id && e.sourceHandle === sourceHandle);
|
||||
if (batchGroupId !== 'None') {
|
||||
// If this batch node has a batch_group_id, we will zip the data collection items
|
||||
for (const edge of edgesFromBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
zippedBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: value,
|
||||
});
|
||||
}
|
||||
|
||||
// Finally, if this batch data collection item has any items, add it to the data array
|
||||
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
|
||||
data.push(zippedBatchDataCollectionItems);
|
||||
} else {
|
||||
// Otherwise add the data collection items to root of the batch so they are not zipped
|
||||
const productBatchDataCollectionItems: NonNullable<Batch['data']>[number] = [];
|
||||
for (const edge of edgesFromBatch) {
|
||||
if (!edge.targetHandle) {
|
||||
break;
|
||||
}
|
||||
productBatchDataCollectionItems.push({
|
||||
node_path: edge.target,
|
||||
field_name: edge.targetHandle,
|
||||
items: value,
|
||||
});
|
||||
}
|
||||
if (productBatchDataCollectionItems.length > 0) {
|
||||
data.push(productBatchDataCollectionItems);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const batchConfig: EnqueueBatchArg = {
|
||||
batch: {
|
||||
graph,
|
||||
workflow: builtWorkflow,
|
||||
runs: state.params.iterations,
|
||||
origin: 'workflows',
|
||||
destination: 'gallery',
|
||||
data,
|
||||
},
|
||||
prepend,
|
||||
};
|
||||
// Finally, if this batch data collection item has any items, add it to the data array
|
||||
if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) {
|
||||
data.push(zippedBatchDataCollectionItems);
|
||||
}
|
||||
}
|
||||
|
||||
if (isApiValidationRun) {
|
||||
// Derive the input fields from the builder's selected node field elements
|
||||
const fieldIdentifiers = selectFieldIdentifiersWithInvocationTypes(state);
|
||||
const inputs = getPublishInputs(fieldIdentifiers, templates);
|
||||
const api_input_fields = inputs.publishable.map(({ nodeId, fieldName, label }) => {
|
||||
return {
|
||||
kind: 'input',
|
||||
node_id: nodeId,
|
||||
field_name: fieldName,
|
||||
user_label: label,
|
||||
} satisfies S['FieldIdentifier'];
|
||||
});
|
||||
return data;
|
||||
};
|
||||
|
||||
// Derive the output fields from the builder's selected output node
|
||||
const outputNodeId = $outputNodeId.get();
|
||||
assert(outputNodeId !== null, 'Output node not selected');
|
||||
const outputNodeType = selectNodeData(selectNodesSlice(state), outputNodeId).type;
|
||||
const outputNodeTemplate = templates[outputNodeType];
|
||||
assert(outputNodeTemplate, `Template for node type ${outputNodeType} not found`);
|
||||
const outputFieldNames = Object.keys(outputNodeTemplate.outputs);
|
||||
const api_output_fields = outputFieldNames.map((fieldName) => {
|
||||
return {
|
||||
kind: 'output',
|
||||
node_id: outputNodeId,
|
||||
field_name: fieldName,
|
||||
user_label: null,
|
||||
} satisfies S['FieldIdentifier'];
|
||||
});
|
||||
const getValidationRunData = (state: RootState, templates: Templates): S['ValidationRunData'] => {
|
||||
const nodesState = selectNodesSlice(state);
|
||||
|
||||
assert(nodesState.id, 'Workflow without ID cannot be used for API validation run');
|
||||
// Derive the input fields from the builder's selected node field elements
|
||||
const fieldIdentifiers = selectFieldIdentifiersWithInvocationTypes(state);
|
||||
const inputs = getPublishInputs(fieldIdentifiers, templates);
|
||||
const api_input_fields = inputs.publishable.map(({ nodeId, fieldName, label }) => {
|
||||
return {
|
||||
kind: 'input',
|
||||
node_id: nodeId,
|
||||
field_name: fieldName,
|
||||
user_label: label,
|
||||
} satisfies S['FieldIdentifier'];
|
||||
});
|
||||
|
||||
batchConfig.validation_run_data = {
|
||||
workflow_id: nodesState.id,
|
||||
input_fields: api_input_fields,
|
||||
output_fields: api_output_fields,
|
||||
};
|
||||
// Derive the output fields from the builder's selected output node
|
||||
const outputNodeId = $outputNodeId.get();
|
||||
assert(outputNodeId !== null, 'Output node not selected');
|
||||
const outputNodeType = selectNodeData(selectNodesSlice(state), outputNodeId).type;
|
||||
const outputNodeTemplate = templates[outputNodeType];
|
||||
assert(outputNodeTemplate, `Template for node type ${outputNodeType} not found`);
|
||||
const outputFieldNames = Object.keys(outputNodeTemplate.outputs);
|
||||
const api_output_fields = outputFieldNames.map((fieldName) => {
|
||||
return {
|
||||
kind: 'output',
|
||||
node_id: outputNodeId,
|
||||
field_name: fieldName,
|
||||
user_label: null,
|
||||
} satisfies S['FieldIdentifier'];
|
||||
});
|
||||
|
||||
// If the batch is an API validation run, we only want to run it once
|
||||
batchConfig.batch.runs = 1;
|
||||
}
|
||||
assert(nodesState.id, 'Workflow without ID cannot be used for API validation run');
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
|
||||
);
|
||||
return {
|
||||
workflow_id: nodesState.id,
|
||||
input_fields: api_input_fields,
|
||||
output_fields: api_output_fields,
|
||||
};
|
||||
};
|
||||
|
||||
const enqueueResult = await req.unwrap();
|
||||
return { batchConfig, enqueueResult };
|
||||
const enqueueWorkflows = async (
|
||||
store: AppStore,
|
||||
templates: Templates,
|
||||
prepend: boolean,
|
||||
isApiValidationRun: boolean
|
||||
) => {
|
||||
const { dispatch, getState } = store;
|
||||
|
||||
dispatch(enqueueRequestedWorkflows());
|
||||
const state = getState();
|
||||
const nodesState = selectNodesSlice(state);
|
||||
const graph = buildNodesGraph(state, templates);
|
||||
const workflow = buildWorkflowWithValidation(nodesState);
|
||||
|
||||
if (workflow) {
|
||||
// embedded workflows don't have an id
|
||||
delete workflow.id;
|
||||
}
|
||||
|
||||
const runs = state.params.iterations;
|
||||
const data = await getBatchDataForWorkflowGeneration(state, dispatch);
|
||||
|
||||
const batchConfig: EnqueueBatchArg = {
|
||||
batch: {
|
||||
graph,
|
||||
workflow,
|
||||
runs,
|
||||
origin: 'workflows',
|
||||
destination: 'gallery',
|
||||
data,
|
||||
},
|
||||
[dispatch, getState]
|
||||
prepend,
|
||||
};
|
||||
|
||||
if (isApiValidationRun) {
|
||||
batchConfig.validation_run_data = getValidationRunData(state, templates);
|
||||
|
||||
// If the batch is an API validation run, we only want to run it once
|
||||
batchConfig.batch.runs = 1;
|
||||
}
|
||||
|
||||
const req = dispatch(
|
||||
queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false })
|
||||
);
|
||||
|
||||
const enqueueResult = await req.unwrap();
|
||||
return { batchConfig, enqueueResult };
|
||||
};
|
||||
|
||||
export const useEnqueueWorkflows = () => {
|
||||
const store = useAppStore();
|
||||
const enqueue = useCallback(
|
||||
(prepend: boolean, isApiValidationRun: boolean) => {
|
||||
return enqueueWorkflows(store, $templates.get(), prepend, isApiValidationRun);
|
||||
},
|
||||
[store]
|
||||
);
|
||||
|
||||
return enqueue;
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import { useStore } from '@nanostores/react';
|
||||
import { logger } from 'app/logging/logger';
|
||||
import { enqueueRequestedCanvas } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear';
|
||||
import { enqueueRequestedUpscaling } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedUpscale';
|
||||
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
|
||||
import { useAppSelector } from 'app/store/storeHooks';
|
||||
import { withResultAsync } from 'common/util/result';
|
||||
import { parseify } from 'common/util/serialize';
|
||||
import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked';
|
||||
import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows';
|
||||
import { $isReadyToEnqueue } from 'features/queue/store/readiness';
|
||||
@@ -15,15 +12,21 @@ import { useCallback } from 'react';
|
||||
import { serializeError } from 'serialize-error';
|
||||
import { enqueueMutationFixedCacheKeyOptions, useEnqueueBatchMutation } from 'services/api/endpoints/queue';
|
||||
|
||||
import { useEnqueueCanvas } from './useEnqueueCanvas';
|
||||
import { useEnqueueGenerate } from './useEnqueueGenerate';
|
||||
import { useEnqueueUpscaling } from './useEnqueueUpscaling';
|
||||
|
||||
const log = logger('generation');
|
||||
|
||||
export const useInvoke = () => {
|
||||
const dispatch = useAppDispatch();
|
||||
const ctx = useAutoLayoutContextSafe();
|
||||
const tabName = useAppSelector(selectActiveTab);
|
||||
const isReady = useStore($isReadyToEnqueue);
|
||||
const isLocked = useIsWorkflowEditorLocked();
|
||||
const enqueueWorkflows = useEnqueueWorkflows();
|
||||
const enqueueCanvas = useEnqueueCanvas();
|
||||
const enqueueGenerate = useEnqueueGenerate();
|
||||
const enqueueUpscaling = useEnqueueUpscaling();
|
||||
|
||||
const [_, { isLoading }] = useEnqueueBatchMutation(enqueueMutationFixedCacheKeyOptions);
|
||||
|
||||
@@ -33,28 +36,26 @@ export const useInvoke = () => {
|
||||
return;
|
||||
}
|
||||
|
||||
if (tabName === 'workflows') {
|
||||
const result = await withResultAsync(() => enqueueWorkflows(prepend, isApiValidationRun));
|
||||
if (result.isErr()) {
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
|
||||
} else {
|
||||
log.debug(parseify(result.value), 'Enqueued batch');
|
||||
const result = await withResultAsync(async () => {
|
||||
switch (tabName) {
|
||||
case 'workflows':
|
||||
return await enqueueWorkflows(prepend, isApiValidationRun);
|
||||
case 'canvas':
|
||||
return await enqueueCanvas(prepend);
|
||||
case 'generate':
|
||||
return await enqueueGenerate(prepend);
|
||||
case 'upscaling':
|
||||
return await enqueueUpscaling(prepend);
|
||||
default:
|
||||
throw new Error(`No enqueue handler for tab: ${tabName}`);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (tabName === 'upscaling') {
|
||||
dispatch(enqueueRequestedUpscaling({ prepend }));
|
||||
return;
|
||||
if (result.isErr()) {
|
||||
log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch');
|
||||
}
|
||||
|
||||
if (tabName === 'canvas' || tabName === 'generate') {
|
||||
dispatch(enqueueRequestedCanvas({ prepend }));
|
||||
return;
|
||||
}
|
||||
|
||||
// Else we are not on a generation tab and should not queue
|
||||
},
|
||||
[dispatch, enqueueWorkflows, isReady, tabName]
|
||||
[enqueueCanvas, enqueueGenerate, enqueueUpscaling, enqueueWorkflows, isReady, tabName]
|
||||
);
|
||||
|
||||
const enqueueBack = useCallback(() => {
|
||||
|
||||
Reference in New Issue
Block a user