From fb883d63aa48ed0482c691e368a83d210c34e357 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 30 Jun 2025 15:33:55 +1000 Subject: [PATCH] refactor(ui): dedicated enqueue funcs for each tab --- .../middleware/listenerMiddleware/index.ts | 4 - .../listeners/enqueueRequestedLinear.ts | 154 ----------- .../listeners/enqueueRequestedUpscale.ts | 44 --- .../features/queue/hooks/useEnqueueCanvas.ts | 143 ++++++++++ .../queue/hooks/useEnqueueGenerate.ts | 137 ++++++++++ .../queue/hooks/useEnqueueUpscaling.ts | 47 ++++ .../queue/hooks/useEnqueueWorkflows.ts | 254 ++++++++++-------- .../web/src/features/queue/hooks/useInvoke.ts | 47 ++-- 8 files changed, 490 insertions(+), 340 deletions(-) delete mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts delete mode 100644 invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedUpscale.ts create mode 100644 invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts create mode 100644 invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts create mode 100644 invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts index a99e350a0a..d2253a9421 100644 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts +++ b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/index.ts @@ -8,7 +8,6 @@ import { addBatchEnqueuedListener } from 'app/store/middleware/listenerMiddlewar import { addDeleteBoardAndImagesFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/boardAndImagesDeleted'; import { addBoardIdSelectedListener } from 'app/store/middleware/listenerMiddleware/listeners/boardIdSelected'; import { addBulkDownloadListeners } from 'app/store/middleware/listenerMiddleware/listeners/bulkDownload'; -import { addEnqueueRequestedLinear } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear'; import { addGetOpenAPISchemaListener } from 'app/store/middleware/listenerMiddleware/listeners/getOpenAPISchema'; import { addImageAddedToBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageAddedToBoard'; import { addImageRemovedFromBoardFulfilledListener } from 'app/store/middleware/listenerMiddleware/listeners/imageRemovedFromBoard'; @@ -20,7 +19,6 @@ import { addSocketConnectedEventListener } from 'app/store/middleware/listenerMi import type { AppDispatch, RootState } from 'app/store/store'; import { addArchivedOrDeletedBoardListener } from './listeners/addArchivedOrDeletedBoardListener'; -import { addEnqueueRequestedUpscale } from './listeners/enqueueRequestedUpscale'; export const listenerMiddleware = createListenerMiddleware(); @@ -43,8 +41,6 @@ addImageUploadedFulfilledListener(startAppListening); addDeleteBoardAndImagesFulfilledListener(startAppListening); // User Invoked -addEnqueueRequestedLinear(startAppListening); -addEnqueueRequestedUpscale(startAppListening); addAnyEnqueuedListener(startAppListening); addBatchEnqueuedListener(startAppListening); diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts deleted file mode 100644 index b39530030e..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear.ts +++ /dev/null @@ -1,154 +0,0 @@ -import type { AlertStatus } from '@invoke-ai/ui-library'; -import { createAction } from '@reduxjs/toolkit'; -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError'; -import { withResult, withResultAsync } from 'common/util/result'; -import { parseify } from 'common/util/serialize'; -import { - canvasSessionIdCreated, - generateSessionIdCreated, - selectCanvasSessionId, - selectGenerateSessionId, -} from 'features/controlLayers/store/canvasStagingAreaSlice'; -import { $canvasManager } from 'features/controlLayers/store/ephemeral'; -import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; -import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph'; -import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph'; -import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph'; -import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph'; -import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph'; -import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph'; -import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; -import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph'; -import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; -import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; -import { toast } from 'features/toast/toast'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; -import { serializeError } from 'serialize-error'; -import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; -import { assert, AssertionError } from 'tsafe'; - -const log = logger('generation'); - -export const enqueueRequestedCanvas = createAction<{ prepend: boolean }>('app/enqueueRequestedCanvas'); - -export const addEnqueueRequestedLinear = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: enqueueRequestedCanvas, - effect: async (action, { getState, dispatch }) => { - log.debug('Enqueue requested'); - - const tab = selectActiveTab(getState()); - let sessionId = null; - if (tab === 'generate') { - sessionId = selectGenerateSessionId(getState()); - if (!sessionId) { - dispatch(generateSessionIdCreated()); - sessionId = selectGenerateSessionId(getState()); - } - } else if (tab === 'canvas') { - sessionId = selectCanvasSessionId(getState()); - if (!sessionId) { - dispatch(canvasSessionIdCreated()); - sessionId = selectCanvasSessionId(getState()); - } - } else { - log.warn(`Enqueue requested in unsupported tab ${tab}`); - return; - } - - const state = getState(); - const destination = sessionId; - assert(destination !== null); - - const { prepend } = action.payload; - - const manager = $canvasManager.get(); - // assert(manager, 'No canvas manager'); - - const model = state.params.model; - assert(model, 'No model found in state'); - const base = model.base; - - const buildGraphResult = await withResultAsync(async () => { - switch (base) { - case 'sdxl': - return await buildSDXLGraph(state, manager); - case 'sd-1': - case `sd-2`: - return await buildSD1Graph(state, manager); - case `sd-3`: - return await buildSD3Graph(state, manager); - case `flux`: - return await buildFLUXGraph(state, manager); - case 'cogview4': - return await buildCogView4Graph(state, manager); - case 'imagen3': - return await buildImagen3Graph(state, manager); - case 'imagen4': - return await buildImagen4Graph(state, manager); - case 'chatgpt-4o': - return await buildChatGPT4oGraph(state, manager); - case 'flux-kontext': - return await buildFluxKontextGraph(state, manager); - default: - assert(false, `No graph builders for base ${base}`); - } - }); - - if (buildGraphResult.isErr()) { - let title = 'Failed to build graph'; - let status: AlertStatus = 'error'; - let description: string | null = null; - if (buildGraphResult.error instanceof AssertionError) { - description = extractMessageFromAssertionError(buildGraphResult.error); - } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { - title = 'Unsupported generation mode'; - description = buildGraphResult.error.message; - status = 'warning'; - } - const error = serializeError(buildGraphResult.error); - log.error({ error }, 'Failed to build graph'); - toast({ - status, - title, - description, - }); - return; - } - - const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value; - - const prepareBatchResult = withResult(() => - prepareLinearUIBatch({ - state, - g, - prepend, - seedFieldIdentifier, - positivePromptFieldIdentifier, - origin: tab, - destination, - }) - ); - - if (prepareBatchResult.isErr()) { - log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); - return; - } - - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(prepareBatchResult.value, enqueueMutationFixedCacheKeyOptions) - ); - - try { - await req.unwrap(); - log.debug(parseify({ batchConfig: prepareBatchResult.value }), 'Enqueued batch'); - } catch (error) { - log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch'); - } finally { - req.reset(); - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedUpscale.ts b/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedUpscale.ts deleted file mode 100644 index a47be0bcc5..0000000000 --- a/invokeai/frontend/web/src/app/store/middleware/listenerMiddleware/listeners/enqueueRequestedUpscale.ts +++ /dev/null @@ -1,44 +0,0 @@ -import { createAction } from '@reduxjs/toolkit'; -import { logger } from 'app/logging/logger'; -import type { AppStartListening } from 'app/store/middleware/listenerMiddleware'; -import { parseify } from 'common/util/serialize'; -import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; -import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph'; -import { serializeError } from 'serialize-error'; -import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; - -const log = logger('generation'); - -export const enqueueRequestedUpscaling = createAction<{ prepend: boolean }>('app/enqueueRequestedUpscaling'); - -export const addEnqueueRequestedUpscale = (startAppListening: AppStartListening) => { - startAppListening({ - actionCreator: enqueueRequestedUpscaling, - effect: async (action, { getState, dispatch }) => { - const state = getState(); - const { prepend } = action.payload; - - const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state); - - const batchConfig = prepareLinearUIBatch({ - state, - g, - prepend, - seedFieldIdentifier, - positivePromptFieldIdentifier, - origin: 'upscaling', - destination: 'gallery', - }); - - const req = dispatch(queueApi.endpoints.enqueueBatch.initiate(batchConfig, enqueueMutationFixedCacheKeyOptions)); - try { - await req.unwrap(); - log.debug(parseify({ batchConfig }), 'Enqueued batch'); - } catch (error) { - log.error({ error: serializeError(error as Error) }, 'Failed to enqueue batch'); - } finally { - req.reset(); - } - }, - }); -}; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts new file mode 100644 index 0000000000..6032a52d46 --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts @@ -0,0 +1,143 @@ +import type { AlertStatus } from '@invoke-ai/ui-library'; +import { createAction } from '@reduxjs/toolkit'; +import { logger } from 'app/logging/logger'; +import type { AppStore } from 'app/store/store'; +import { useAppStore } from 'app/store/storeHooks'; +import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError'; +import { withResult, withResultAsync } from 'common/util/result'; +import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import { canvasSessionIdCreated, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; +import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; +import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph'; +import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph'; +import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph'; +import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph'; +import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph'; +import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph'; +import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; +import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph'; +import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; +import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; +import { toast } from 'features/toast/toast'; +import { useCallback } from 'react'; +import { serializeError } from 'serialize-error'; +import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; +import { assert, AssertionError } from 'tsafe'; + +const log = logger('generation'); +export const enqueueRequestedCanvas = createAction('app/enqueueRequestedCanvas'); + +const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prepend: boolean) => { + const { dispatch, getState } = store; + + dispatch(enqueueRequestedCanvas()); + + let destination = selectCanvasSessionId(getState()); + if (!destination) { + dispatch(canvasSessionIdCreated()); + destination = selectCanvasSessionId(getState()); + } + assert(destination !== null); + + const state = getState(); + + const model = state.params.model; + assert(model, 'No model found in state'); + const base = model.base; + + const buildGraphResult = await withResultAsync(async () => { + switch (base) { + case 'sdxl': + return await buildSDXLGraph(state, canvasManager); + case 'sd-1': + case `sd-2`: + return await buildSD1Graph(state, canvasManager); + case `sd-3`: + return await buildSD3Graph(state, canvasManager); + case `flux`: + return await buildFLUXGraph(state, canvasManager); + case 'cogview4': + return await buildCogView4Graph(state, canvasManager); + case 'imagen3': + return await buildImagen3Graph(state, canvasManager); + case 'imagen4': + return await buildImagen4Graph(state, canvasManager); + case 'chatgpt-4o': + return await buildChatGPT4oGraph(state, canvasManager); + case 'flux-kontext': + return await buildFluxKontextGraph(state, canvasManager); + default: + assert(false, `No graph builders for base ${base}`); + } + }); + + if (buildGraphResult.isErr()) { + let title = 'Failed to build graph'; + let status: AlertStatus = 'error'; + let description: string | null = null; + if (buildGraphResult.error instanceof AssertionError) { + description = extractMessageFromAssertionError(buildGraphResult.error); + } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { + title = 'Unsupported generation mode'; + description = buildGraphResult.error.message; + status = 'warning'; + } + const error = serializeError(buildGraphResult.error); + log.error({ error }, 'Failed to build graph'); + toast({ + status, + title, + description, + }); + return; + } + + const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value; + + const prepareBatchResult = withResult(() => + prepareLinearUIBatch({ + state, + g, + prepend, + seedFieldIdentifier, + positivePromptFieldIdentifier, + origin: 'canvas', + destination, + }) + ); + + if (prepareBatchResult.isErr()) { + log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); + return; + } + + const batchConfig = prepareBatchResult.value; + + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(batchConfig, { + ...enqueueMutationFixedCacheKeyOptions, + track: false, + }) + ); + + const enqueueResult = await req.unwrap(); + + return { batchConfig, enqueueResult }; +}; + +export const useEnqueueCanvas = () => { + const store = useAppStore(); + const canvasManager = useCanvasManagerSafe(); + const enqueue = useCallback( + (prepend: boolean) => { + if (!canvasManager) { + log.error('Canvas manager is not available'); + return; + } + return enqueueCanvas(store, canvasManager, prepend); + }, + [canvasManager, store] + ); + return enqueue; +}; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts new file mode 100644 index 0000000000..9872861dc2 --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts @@ -0,0 +1,137 @@ +import type { AlertStatus } from '@invoke-ai/ui-library'; +import { createAction } from '@reduxjs/toolkit'; +import { logger } from 'app/logging/logger'; +import type { AppStore } from 'app/store/store'; +import { useAppStore } from 'app/store/storeHooks'; +import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError'; +import { withResult, withResultAsync } from 'common/util/result'; +import { generateSessionIdCreated, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; +import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; +import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph'; +import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph'; +import { buildFLUXGraph } from 'features/nodes/util/graph/generation/buildFLUXGraph'; +import { buildFluxKontextGraph } from 'features/nodes/util/graph/generation/buildFluxKontextGraph'; +import { buildImagen3Graph } from 'features/nodes/util/graph/generation/buildImagen3Graph'; +import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildImagen4Graph'; +import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; +import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph'; +import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; +import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; +import { toast } from 'features/toast/toast'; +import { useCallback } from 'react'; +import { serializeError } from 'serialize-error'; +import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; +import { assert, AssertionError } from 'tsafe'; + +const log = logger('generation'); + +export const enqueueRequestedGenerate = createAction('app/enqueueRequestedGenerate'); + +const enqueueGenerate = async (store: AppStore, prepend: boolean) => { + const { dispatch, getState } = store; + + dispatch(enqueueRequestedGenerate()); + + let destination = selectGenerateSessionId(getState()); + if (!destination) { + dispatch(generateSessionIdCreated()); + destination = selectGenerateSessionId(getState()); + } + assert(destination !== null); + + const state = getState(); + + const model = state.params.model; + assert(model, 'No model found in state'); + const base = model.base; + + const buildGraphResult = await withResultAsync(async () => { + switch (base) { + case 'sdxl': + return await buildSDXLGraph(state, null); + case 'sd-1': + case `sd-2`: + return await buildSD1Graph(state, null); + case `sd-3`: + return await buildSD3Graph(state, null); + case `flux`: + return await buildFLUXGraph(state, null); + case 'cogview4': + return await buildCogView4Graph(state, null); + case 'imagen3': + return await buildImagen3Graph(state, null); + case 'imagen4': + return await buildImagen4Graph(state, null); + case 'chatgpt-4o': + return await buildChatGPT4oGraph(state, null); + case 'flux-kontext': + return await buildFluxKontextGraph(state, null); + default: + assert(false, `No graph builders for base ${base}`); + } + }); + + if (buildGraphResult.isErr()) { + let title = 'Failed to build graph'; + let status: AlertStatus = 'error'; + let description: string | null = null; + if (buildGraphResult.error instanceof AssertionError) { + description = extractMessageFromAssertionError(buildGraphResult.error); + } else if (buildGraphResult.error instanceof UnsupportedGenerationModeError) { + title = 'Unsupported generation mode'; + description = buildGraphResult.error.message; + status = 'warning'; + } + const error = serializeError(buildGraphResult.error); + log.error({ error }, 'Failed to build graph'); + toast({ + status, + title, + description, + }); + return; + } + + const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = buildGraphResult.value; + + const prepareBatchResult = withResult(() => + prepareLinearUIBatch({ + state, + g, + prepend, + seedFieldIdentifier, + positivePromptFieldIdentifier, + origin: 'canvas', + destination, + }) + ); + + if (prepareBatchResult.isErr()) { + log.error({ error: serializeError(prepareBatchResult.error) }, 'Failed to prepare batch'); + return; + } + + const batchConfig = prepareBatchResult.value; + + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(batchConfig, { + ...enqueueMutationFixedCacheKeyOptions, + track: false, + }) + ); + + const enqueueResult = await req.unwrap(); + + return { batchConfig, enqueueResult }; +}; + +export const useEnqueueGenerate = () => { + const store = useAppStore(); + const enqueue = useCallback( + (prepend: boolean) => { + return enqueueGenerate(store, prepend); + }, + [store] + ); + return enqueue; +}; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts new file mode 100644 index 0000000000..71ec1b71a6 --- /dev/null +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueUpscaling.ts @@ -0,0 +1,47 @@ +import { createAction } from '@reduxjs/toolkit'; +import type { AppStore } from 'app/store/store'; +import { useAppStore } from 'app/store/storeHooks'; +import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; +import { buildMultidiffusionUpscaleGraph } from 'features/nodes/util/graph/buildMultidiffusionUpscaleGraph'; +import { useCallback } from 'react'; +import { enqueueMutationFixedCacheKeyOptions, queueApi } from 'services/api/endpoints/queue'; + +const enqueueRequestedUpscaling = createAction('app/enqueueRequestedUpscaling'); + +const enqueueUpscaling = async (store: AppStore, prepend: boolean) => { + const { dispatch, getState } = store; + + dispatch(enqueueRequestedUpscaling()); + + const state = getState(); + + const { g, seedFieldIdentifier, positivePromptFieldIdentifier } = await buildMultidiffusionUpscaleGraph(state); + + const batchConfig = prepareLinearUIBatch({ + state, + g, + prepend, + seedFieldIdentifier, + positivePromptFieldIdentifier, + origin: 'upscaling', + destination: 'gallery', + }); + + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false }) + ); + const enqueueResult = await req.unwrap(); + + return { batchConfig, enqueueResult }; +}; + +export const useEnqueueUpscaling = () => { + const store = useAppStore(); + const enqueue = useCallback( + (prepend: boolean) => { + return enqueueUpscaling(store, prepend); + }, + [store] + ); + return enqueue; +}; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueWorkflows.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueWorkflows.ts index 4460fe553c..0e245cf581 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueWorkflows.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueWorkflows.ts @@ -1,4 +1,5 @@ import { createAction } from '@reduxjs/toolkit'; +import type { AppDispatch, AppStore, RootState } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; import { groupBy } from 'es-toolkit/compat'; import { @@ -8,6 +9,7 @@ import { } from 'features/nodes/components/sidePanel/workflow/publish'; import { $templates } from 'features/nodes/store/nodesSlice'; import { selectNodeData, selectNodesSlice } from 'features/nodes/store/selectors'; +import type { Templates } from 'features/nodes/store/types'; import { isBatchNode, isInvocationNode } from 'features/nodes/types/invocation'; import { buildNodesGraph } from 'features/nodes/util/graph/buildNodesGraph'; import { resolveBatchValue } from 'features/nodes/util/node/resolveBatchValue'; @@ -19,138 +21,160 @@ import { assert } from 'tsafe'; const enqueueRequestedWorkflows = createAction('app/enqueueRequestedWorkflows'); -export const useEnqueueWorkflows = () => { - const { getState, dispatch } = useAppStore(); - const enqueue = useCallback( - async (prepend: boolean, isApiValidationRun: boolean) => { - dispatch(enqueueRequestedWorkflows()); - const state = getState(); - const nodesState = selectNodesSlice(state); - const templates = $templates.get(); - const graph = buildNodesGraph(state, templates); - const builtWorkflow = buildWorkflowWithValidation(nodesState); +const getBatchDataForWorkflowGeneration = async (state: RootState, dispatch: AppDispatch): Promise => { + const nodesState = selectNodesSlice(state); + const data: Batch['data'] = []; - if (builtWorkflow) { - // embedded workflows don't have an id - delete builtWorkflow.id; - } + const invocationNodes = nodesState.nodes.filter(isInvocationNode); + const batchNodes = invocationNodes.filter(isBatchNode); - const data: Batch['data'] = []; + // Handle zipping batch nodes. First group the batch nodes by their batch_group_id + const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value); - const invocationNodes = nodesState.nodes.filter(isInvocationNode); - const batchNodes = invocationNodes.filter(isBatchNode); + // Then, we will create a batch data collection item for each group + for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) { + const zippedBatchDataCollectionItems: NonNullable[number] = []; - // Handle zipping batch nodes. First group the batch nodes by their batch_group_id - const groupedBatchNodes = groupBy(batchNodes, (node) => node.data.inputs['batch_group_id']?.value); - - // Then, we will create a batch data collection item for each group - for (const [batchGroupId, batchNodes] of Object.entries(groupedBatchNodes)) { - const zippedBatchDataCollectionItems: NonNullable[number] = []; - - for (const node of batchNodes) { - const value = await resolveBatchValue({ nodesState, node, dispatch }); - const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value'; - const edgesFromBatch = nodesState.edges.filter( - (e) => e.source === node.id && e.sourceHandle === sourceHandle - ); - if (batchGroupId !== 'None') { - // If this batch node has a batch_group_id, we will zip the data collection items - for (const edge of edgesFromBatch) { - if (!edge.targetHandle) { - break; - } - zippedBatchDataCollectionItems.push({ - node_path: edge.target, - field_name: edge.targetHandle, - items: value, - }); - } - } else { - // Otherwise add the data collection items to root of the batch so they are not zipped - const productBatchDataCollectionItems: NonNullable[number] = []; - for (const edge of edgesFromBatch) { - if (!edge.targetHandle) { - break; - } - productBatchDataCollectionItems.push({ - node_path: edge.target, - field_name: edge.targetHandle, - items: value, - }); - } - if (productBatchDataCollectionItems.length > 0) { - data.push(productBatchDataCollectionItems); - } + for (const node of batchNodes) { + const value = await resolveBatchValue({ nodesState, node, dispatch }); + const sourceHandle = node.data.type === 'image_batch' ? 'image' : 'value'; + const edgesFromBatch = nodesState.edges.filter((e) => e.source === node.id && e.sourceHandle === sourceHandle); + if (batchGroupId !== 'None') { + // If this batch node has a batch_group_id, we will zip the data collection items + for (const edge of edgesFromBatch) { + if (!edge.targetHandle) { + break; } + zippedBatchDataCollectionItems.push({ + node_path: edge.target, + field_name: edge.targetHandle, + items: value, + }); } - - // Finally, if this batch data collection item has any items, add it to the data array - if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) { - data.push(zippedBatchDataCollectionItems); + } else { + // Otherwise add the data collection items to root of the batch so they are not zipped + const productBatchDataCollectionItems: NonNullable[number] = []; + for (const edge of edgesFromBatch) { + if (!edge.targetHandle) { + break; + } + productBatchDataCollectionItems.push({ + node_path: edge.target, + field_name: edge.targetHandle, + items: value, + }); + } + if (productBatchDataCollectionItems.length > 0) { + data.push(productBatchDataCollectionItems); } } + } - const batchConfig: EnqueueBatchArg = { - batch: { - graph, - workflow: builtWorkflow, - runs: state.params.iterations, - origin: 'workflows', - destination: 'gallery', - data, - }, - prepend, - }; + // Finally, if this batch data collection item has any items, add it to the data array + if (batchGroupId !== 'None' && zippedBatchDataCollectionItems.length > 0) { + data.push(zippedBatchDataCollectionItems); + } + } - if (isApiValidationRun) { - // Derive the input fields from the builder's selected node field elements - const fieldIdentifiers = selectFieldIdentifiersWithInvocationTypes(state); - const inputs = getPublishInputs(fieldIdentifiers, templates); - const api_input_fields = inputs.publishable.map(({ nodeId, fieldName, label }) => { - return { - kind: 'input', - node_id: nodeId, - field_name: fieldName, - user_label: label, - } satisfies S['FieldIdentifier']; - }); + return data; +}; - // Derive the output fields from the builder's selected output node - const outputNodeId = $outputNodeId.get(); - assert(outputNodeId !== null, 'Output node not selected'); - const outputNodeType = selectNodeData(selectNodesSlice(state), outputNodeId).type; - const outputNodeTemplate = templates[outputNodeType]; - assert(outputNodeTemplate, `Template for node type ${outputNodeType} not found`); - const outputFieldNames = Object.keys(outputNodeTemplate.outputs); - const api_output_fields = outputFieldNames.map((fieldName) => { - return { - kind: 'output', - node_id: outputNodeId, - field_name: fieldName, - user_label: null, - } satisfies S['FieldIdentifier']; - }); +const getValidationRunData = (state: RootState, templates: Templates): S['ValidationRunData'] => { + const nodesState = selectNodesSlice(state); - assert(nodesState.id, 'Workflow without ID cannot be used for API validation run'); + // Derive the input fields from the builder's selected node field elements + const fieldIdentifiers = selectFieldIdentifiersWithInvocationTypes(state); + const inputs = getPublishInputs(fieldIdentifiers, templates); + const api_input_fields = inputs.publishable.map(({ nodeId, fieldName, label }) => { + return { + kind: 'input', + node_id: nodeId, + field_name: fieldName, + user_label: label, + } satisfies S['FieldIdentifier']; + }); - batchConfig.validation_run_data = { - workflow_id: nodesState.id, - input_fields: api_input_fields, - output_fields: api_output_fields, - }; + // Derive the output fields from the builder's selected output node + const outputNodeId = $outputNodeId.get(); + assert(outputNodeId !== null, 'Output node not selected'); + const outputNodeType = selectNodeData(selectNodesSlice(state), outputNodeId).type; + const outputNodeTemplate = templates[outputNodeType]; + assert(outputNodeTemplate, `Template for node type ${outputNodeType} not found`); + const outputFieldNames = Object.keys(outputNodeTemplate.outputs); + const api_output_fields = outputFieldNames.map((fieldName) => { + return { + kind: 'output', + node_id: outputNodeId, + field_name: fieldName, + user_label: null, + } satisfies S['FieldIdentifier']; + }); - // If the batch is an API validation run, we only want to run it once - batchConfig.batch.runs = 1; - } + assert(nodesState.id, 'Workflow without ID cannot be used for API validation run'); - const req = dispatch( - queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false }) - ); + return { + workflow_id: nodesState.id, + input_fields: api_input_fields, + output_fields: api_output_fields, + }; +}; - const enqueueResult = await req.unwrap(); - return { batchConfig, enqueueResult }; +const enqueueWorkflows = async ( + store: AppStore, + templates: Templates, + prepend: boolean, + isApiValidationRun: boolean +) => { + const { dispatch, getState } = store; + + dispatch(enqueueRequestedWorkflows()); + const state = getState(); + const nodesState = selectNodesSlice(state); + const graph = buildNodesGraph(state, templates); + const workflow = buildWorkflowWithValidation(nodesState); + + if (workflow) { + // embedded workflows don't have an id + delete workflow.id; + } + + const runs = state.params.iterations; + const data = await getBatchDataForWorkflowGeneration(state, dispatch); + + const batchConfig: EnqueueBatchArg = { + batch: { + graph, + workflow, + runs, + origin: 'workflows', + destination: 'gallery', + data, }, - [dispatch, getState] + prepend, + }; + + if (isApiValidationRun) { + batchConfig.validation_run_data = getValidationRunData(state, templates); + + // If the batch is an API validation run, we only want to run it once + batchConfig.batch.runs = 1; + } + + const req = dispatch( + queueApi.endpoints.enqueueBatch.initiate(batchConfig, { ...enqueueMutationFixedCacheKeyOptions, track: false }) + ); + + const enqueueResult = await req.unwrap(); + return { batchConfig, enqueueResult }; +}; + +export const useEnqueueWorkflows = () => { + const store = useAppStore(); + const enqueue = useCallback( + (prepend: boolean, isApiValidationRun: boolean) => { + return enqueueWorkflows(store, $templates.get(), prepend, isApiValidationRun); + }, + [store] ); return enqueue; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useInvoke.ts b/invokeai/frontend/web/src/features/queue/hooks/useInvoke.ts index 211bb67cc7..d9f50916b8 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useInvoke.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useInvoke.ts @@ -1,10 +1,7 @@ import { useStore } from '@nanostores/react'; import { logger } from 'app/logging/logger'; -import { enqueueRequestedCanvas } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedLinear'; -import { enqueueRequestedUpscaling } from 'app/store/middleware/listenerMiddleware/listeners/enqueueRequestedUpscale'; -import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; +import { useAppSelector } from 'app/store/storeHooks'; import { withResultAsync } from 'common/util/result'; -import { parseify } from 'common/util/serialize'; import { useIsWorkflowEditorLocked } from 'features/nodes/hooks/useIsWorkflowEditorLocked'; import { useEnqueueWorkflows } from 'features/queue/hooks/useEnqueueWorkflows'; import { $isReadyToEnqueue } from 'features/queue/store/readiness'; @@ -15,15 +12,21 @@ import { useCallback } from 'react'; import { serializeError } from 'serialize-error'; import { enqueueMutationFixedCacheKeyOptions, useEnqueueBatchMutation } from 'services/api/endpoints/queue'; +import { useEnqueueCanvas } from './useEnqueueCanvas'; +import { useEnqueueGenerate } from './useEnqueueGenerate'; +import { useEnqueueUpscaling } from './useEnqueueUpscaling'; + const log = logger('generation'); export const useInvoke = () => { - const dispatch = useAppDispatch(); const ctx = useAutoLayoutContextSafe(); const tabName = useAppSelector(selectActiveTab); const isReady = useStore($isReadyToEnqueue); const isLocked = useIsWorkflowEditorLocked(); const enqueueWorkflows = useEnqueueWorkflows(); + const enqueueCanvas = useEnqueueCanvas(); + const enqueueGenerate = useEnqueueGenerate(); + const enqueueUpscaling = useEnqueueUpscaling(); const [_, { isLoading }] = useEnqueueBatchMutation(enqueueMutationFixedCacheKeyOptions); @@ -33,28 +36,26 @@ export const useInvoke = () => { return; } - if (tabName === 'workflows') { - const result = await withResultAsync(() => enqueueWorkflows(prepend, isApiValidationRun)); - if (result.isErr()) { - log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch'); - } else { - log.debug(parseify(result.value), 'Enqueued batch'); + const result = await withResultAsync(async () => { + switch (tabName) { + case 'workflows': + return await enqueueWorkflows(prepend, isApiValidationRun); + case 'canvas': + return await enqueueCanvas(prepend); + case 'generate': + return await enqueueGenerate(prepend); + case 'upscaling': + return await enqueueUpscaling(prepend); + default: + throw new Error(`No enqueue handler for tab: ${tabName}`); } - } + }); - if (tabName === 'upscaling') { - dispatch(enqueueRequestedUpscaling({ prepend })); - return; + if (result.isErr()) { + log.error({ error: serializeError(result.error) }, 'Failed to enqueue batch'); } - - if (tabName === 'canvas' || tabName === 'generate') { - dispatch(enqueueRequestedCanvas({ prepend })); - return; - } - - // Else we are not on a generation tab and should not queue }, - [dispatch, enqueueWorkflows, isReady, tabName] + [enqueueCanvas, enqueueGenerate, enqueueUpscaling, enqueueWorkflows, isReady, tabName] ); const enqueueBack = useCallback(() => {