From a035645ed3daf52acd85c215d142dbb344cea0ba Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 30 Jun 2025 16:34:33 +1000 Subject: [PATCH] refactor(ui): graph building respects selected tab --- .../store/canvasStagingAreaSlice.ts | 25 ++----- .../graph/generation/buildChatGPT4oGraph.ts | 20 ++---- .../graph/generation/buildCogView4Graph.ts | 23 ++---- .../util/graph/generation/buildFLUXGraph.ts | 40 ++++------- .../graph/generation/buildFluxKontextGraph.ts | 63 +++++++---------- .../graph/generation/buildImagen3Graph.ts | 69 ++++++++---------- .../graph/generation/buildImagen4Graph.ts | 70 ++++++++----------- .../util/graph/generation/buildSD1Graph.ts | 30 +++----- .../util/graph/generation/buildSD3Graph.ts | 20 ++---- .../util/graph/generation/buildSDXLGraph.ts | 30 +++----- .../nodes/util/graph/graphBuilderUtils.ts | 19 ++++- .../src/features/nodes/util/graph/types.ts | 14 ++++ .../features/queue/hooks/useEnqueueCanvas.ts | 44 ++++++------ .../queue/hooks/useEnqueueGenerate.ts | 43 ++++++------ 14 files changed, 223 insertions(+), 287 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts b/invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts index cb5e30aa4e..8c0da9c86f 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/canvasStagingAreaSlice.ts @@ -1,7 +1,6 @@ import { createSelector, createSlice, type PayloadAction } from '@reduxjs/toolkit'; import type { PersistConfig, RootState } from 'app/store/store'; import { deepClone } from 'common/util/deepClone'; -import { getPrefixedId } from 'features/controlLayers/konva/util'; import { canvasReset } from 'features/controlLayers/store/actions'; type CanvasStagingAreaState = { @@ -20,26 +19,16 @@ export const canvasSessionSlice = createSlice({ name: 'canvasSession', initialState: getInitialState(), reducers: { - generateSessionIdCreated: { - reducer: (state, action: PayloadAction<{ id: string }>) => { - const { id } = action.payload; - state.generateSessionId = id; - }, - prepare: () => ({ - payload: { id: getPrefixedId('generate') }, - }), + generateSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => { + const { id } = action.payload; + state.generateSessionId = id; }, generateSessionReset: (state) => { state.generateSessionId = null; }, - canvasSessionIdCreated: { - reducer: (state, action: PayloadAction<{ id: string }>) => { - const { id } = action.payload; - state.canvasSessionId = id; - }, - prepare: () => ({ - payload: { id: getPrefixedId('canvas') }, - }), + canvasSessionIdChanged: (state, action: PayloadAction<{ id: string }>) => { + const { id } = action.payload; + state.canvasSessionId = id; }, canvasSessionReset: (state) => { state.canvasSessionId = null; @@ -52,7 +41,7 @@ export const canvasSessionSlice = createSlice({ }, }); -export const { generateSessionIdCreated, generateSessionReset, canvasSessionIdCreated, canvasSessionReset } = +export const { generateSessionIdChanged, generateSessionReset, canvasSessionIdChanged, canvasSessionReset } = canvasSessionSlice.actions; /* eslint-disable-next-line @typescript-eslint/no-explicit-any */ diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildChatGPT4oGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildChatGPT4oGraph.ts index 21ad0fe8e4..eb64c4b19b 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildChatGPT4oGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildChatGPT4oGraph.ts @@ -1,6 +1,4 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice'; import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; @@ -8,23 +6,18 @@ import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { isChatGPT4oAspectRatioID, isChatGPT4oReferenceImageConfig } from 'features/controlLayers/store/types'; import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; import { type ImageField, zModelIdentifierField } from 'features/nodes/types/common'; -import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; -import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types'; +import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { t } from 'i18next'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; const log = logger('system'); -export const buildChatGPT4oGraph = async ( - state: RootState, - manager: CanvasManager | null -): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildChatGPT4oGraph = async (arg: GraphBuilderArg): Promise => { + const { generationMode, state } = arg; if (generationMode !== 'txt2img' && generationMode !== 'img2img') { throw new UnsupportedGenerationModeError(t('toast.chatGPT4oIncompatibleGenerationMode')); @@ -86,9 +79,8 @@ export const buildChatGPT4oGraph = async ( } if (generationMode === 'img2img') { - assert(manager, 'Need manager to do img2img'); - const adapters = manager.compositor.getVisibleAdaptersOfType('raster_layer'); - const { image_name } = await manager.compositor.getCompositeImageDTO(adapters, bbox.rect, { + const adapters = arg.canvasManager.compositor.getVisibleAdaptersOfType('raster_layer'); + const { image_name } = await arg.canvasManager.compositor.getCompositeImageDTO(adapters, bbox.rect, { is_intermediate: true, silent: true, }); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts index d2b27dd7cc..297e09b551 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildCogView4Graph.ts @@ -1,6 +1,4 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; @@ -11,15 +9,13 @@ import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChec import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; -import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { getSizes, selectCanvasOutputFields, selectPresetModifiedPrompts, } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import type { Invocation } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types'; import type { Equals } from 'tsafe'; @@ -27,12 +23,8 @@ import { assert } from 'tsafe'; const log = logger('system'); -export const buildCogView4Graph = async ( - state: RootState, - manager: CanvasManager | null -): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildCogView4Graph = async (arg: GraphBuilderArg): Promise => { + const { generationMode, state } = arg; log.debug({ generationMode }, 'Building CogView4 graph'); const params = selectParamsSlice(state); @@ -112,10 +104,9 @@ export const buildCogView4Graph = async ( canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize }); g.upsertMetadata({ generation_mode: 'cogview4_txt2img' }); } else if (generationMode === 'img2img') { - assert(manager, 'Need manager to do img2img'); canvasOutput = await addImageToImage({ g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'cogview4_i2l', denoise, @@ -128,11 +119,10 @@ export const buildCogView4Graph = async ( }); g.upsertMetadata({ generation_mode: 'cogview4_img2img' }); } else if (generationMode === 'inpaint') { - assert(manager, 'Need manager to do inpaint'); canvasOutput = await addInpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'cogview4_i2l', denoise, @@ -146,11 +136,10 @@ export const buildCogView4Graph = async ( }); g.upsertMetadata({ generation_mode: 'cogview4_inpaint' }); } else if (generationMode === 'outpaint') { - assert(manager, 'Need manager to do outpaint'); canvasOutput = await addOutpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'cogview4_i2l', denoise, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 2a012bd85a..5ccff3f4d8 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -1,6 +1,4 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; @@ -15,19 +13,14 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint'; import { addRegions } from 'features/nodes/util/graph/generation/addRegions'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; -import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { getSizes, selectCanvasOutputFields, selectPresetModifiedPrompts, } from 'features/nodes/util/graph/graphBuilderUtils'; -import { - type GraphBuilderReturn, - type ImageOutputNodes, - UnsupportedGenerationModeError, -} from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; +import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { t } from 'i18next'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; @@ -38,9 +31,8 @@ import { addIPAdapters } from './addIPAdapters'; const log = logger('system'); -export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | null): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildFLUXGraph = async (arg: GraphBuilderArg): Promise => { + const { generationMode, state } = arg; log.debug({ generationMode }, 'Building FLUX graph'); const params = selectParamsSlice(state); @@ -171,12 +163,11 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | let canvasOutput: Invocation = l2i; - if (isFLUXFill) { - assert(manager, 'Need manager to do FLUX Fill'); + if (isFLUXFill && (generationMode === 'inpaint' || generationMode === 'outpaint')) { canvasOutput = await addFLUXFill({ state, g, - manager, + manager: arg.canvasManager, l2i, denoise, originalSize, @@ -186,10 +177,9 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize }); g.upsertMetadata({ generation_mode: 'flux_txt2img' }); } else if (generationMode === 'img2img') { - assert(manager, 'Need manager to do img2img'); canvasOutput = await addImageToImage({ g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'flux_vae_encode', denoise, @@ -202,11 +192,10 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | }); g.upsertMetadata({ generation_mode: 'flux_img2img' }); } else if (generationMode === 'inpaint') { - assert(manager, 'Need manager to do inpaint'); canvasOutput = await addInpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'flux_vae_encode', denoise, @@ -220,11 +209,10 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | }); g.upsertMetadata({ generation_mode: 'flux_inpaint' }); } else if (generationMode === 'outpaint') { - assert(manager, 'Need manager to do outpaint'); canvasOutput = await addOutpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'flux_vae_encode', denoise, @@ -241,13 +229,13 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | assert>(false); } - if (manager) { + if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') { const controlNetCollector = g.addNode({ type: 'collect', id: getPrefixedId('control_net_collector'), }); const controlNetResult = await addControlNets({ - manager, + manager: arg.canvasManager, entities: canvas.controlLayers.entities, g, rect: canvas.bbox.rect, @@ -261,7 +249,7 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | } await addControlLoRA({ - manager, + manager: arg.canvasManager, entities: canvas.controlLayers.entities, g, rect: canvas.bbox.rect, @@ -295,9 +283,9 @@ export const buildFLUXGraph = async (state: RootState, manager: CanvasManager | }); let totalReduxesAdded = fluxReduxResult.addedFLUXReduxes; - if (manager) { + if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') { const regionsResult = await addRegions({ - manager, + manager: arg.canvasManager, regions: canvas.regionalGuidance.entities, g, bbox: canvas.bbox.rect, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFluxKontextGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFluxKontextGraph.ts index 70ea82b346..e8a55a7e8d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFluxKontextGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFluxKontextGraph.ts @@ -1,6 +1,4 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice'; import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; @@ -11,22 +9,15 @@ import type { ImageField } from 'features/nodes/types/common'; import { zModelIdentifierField } from 'features/nodes/types/common'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; -import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types'; +import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { t } from 'i18next'; -import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; -import { getGenerationMode } from './getGenerationMode'; - const log = logger('system'); -export const buildFluxKontextGraph = async ( - state: RootState, - manager: CanvasManager | null -): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildFluxKontextGraph = (arg: GraphBuilderArg): GraphBuilderReturn => { + const { generationMode, state } = arg; if (generationMode !== 'txt2img') { throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'FLUX Kontext' })); @@ -61,29 +52,25 @@ export const buildFluxKontextGraph = async ( }; } - if (generationMode === 'txt2img') { - const g = new Graph(getPrefixedId('flux_kontext_txt2img_graph')); - const fluxKontextImage = g.addNode({ - // @ts-expect-error: These nodes are not available in the OSS application - type: input_image ? 'flux_kontext_edit_image' : 'flux_kontext_generate_image', - model: zModelIdentifierField.parse(model), - positive_prompt: positivePrompt, - aspect_ratio: bbox.aspectRatio.id, - input_image, - prompt_upsampling: true, - ...selectCanvasOutputFields(state), - }); - g.upsertMetadata({ - positive_prompt: positivePrompt, - model: Graph.getModelMetadataField(model), - width: bbox.rect.width, - height: bbox.rect.height, - }); - return { - g, - positivePromptFieldIdentifier: { nodeId: fluxKontextImage.id, fieldName: 'positive_prompt' }, - }; - } - - assert>(false, 'Invalid generation mode for Flux Kontext'); + const g = new Graph(getPrefixedId('flux_kontext_txt2img_graph')); + const fluxKontextImage = g.addNode({ + // @ts-expect-error: These nodes are not available in the OSS application + type: input_image ? 'flux_kontext_edit_image' : 'flux_kontext_generate_image', + model: zModelIdentifierField.parse(model), + positive_prompt: positivePrompt, + aspect_ratio: bbox.aspectRatio.id, + input_image, + prompt_upsampling: true, + ...selectCanvasOutputFields(state), + }); + g.upsertMetadata({ + positive_prompt: positivePrompt, + model: Graph.getModelMetadataField(model), + width: bbox.rect.width, + height: bbox.rect.height, + }); + return { + g, + positivePromptFieldIdentifier: { nodeId: fluxKontextImage.id, fieldName: 'positive_prompt' }, + }; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImagen3Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImagen3Graph.ts index 15c478798f..95daa1118d 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImagen3Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImagen3Graph.ts @@ -1,28 +1,20 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { isImagenAspectRatioID } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; -import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; -import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types'; +import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { t } from 'i18next'; -import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; const log = logger('system'); -export const buildImagen3Graph = async ( - state: RootState, - manager: CanvasManager | null -): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildImagen3Graph = (arg: GraphBuilderArg): GraphBuilderReturn => { + const { generationMode, state } = arg; if (generationMode !== 'txt2img') { throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen3' })); @@ -41,33 +33,30 @@ export const buildImagen3Graph = async ( assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen3 does not support this aspect ratio'); assert(positivePrompt.length > 0, 'Imagen3 requires positive prompt to have at least one character'); - if (generationMode === 'txt2img') { - const g = new Graph(getPrefixedId('imagen3_txt2img_graph')); - const imagen3 = g.addNode({ - // @ts-expect-error: These nodes are not available in the OSS application - type: 'google_imagen3_generate_image', - model: zModelIdentifierField.parse(model), - positive_prompt: positivePrompt, - negative_prompt: negativePrompt, - aspect_ratio: bbox.aspectRatio.id, - // When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed. - enhance_prompt: true, - ...selectCanvasOutputFields(state), - }); - g.upsertMetadata({ - positive_prompt: positivePrompt, - negative_prompt: negativePrompt, - width: bbox.rect.width, - height: bbox.rect.height, - model: Graph.getModelMetadataField(model), - ...selectCanvasMetadata(state), - }); - return { - g, - seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' }, - positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' }, - }; - } + const g = new Graph(getPrefixedId('imagen3_txt2img_graph')); + const imagen3 = g.addNode({ + // @ts-expect-error: These nodes are not available in the OSS application + type: 'google_imagen3_generate_image', + model: zModelIdentifierField.parse(model), + positive_prompt: positivePrompt, + negative_prompt: negativePrompt, + aspect_ratio: bbox.aspectRatio.id, + // When enhance_prompt is true, Imagen3 will return a new image every time, ignoring the seed. + enhance_prompt: true, + ...selectCanvasOutputFields(state), + }); + g.upsertMetadata({ + positive_prompt: positivePrompt, + negative_prompt: negativePrompt, + width: bbox.rect.width, + height: bbox.rect.height, + model: Graph.getModelMetadataField(model), + ...selectCanvasMetadata(state), + }); - assert>(false, 'Invalid generation mode for Imagen3'); + return { + g, + seedFieldIdentifier: { nodeId: imagen3.id, fieldName: 'seed' }, + positivePromptFieldIdentifier: { nodeId: imagen3.id, fieldName: 'positive_prompt' }, + }; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImagen4Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImagen4Graph.ts index b83c67f21c..c97faf29b4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImagen4Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildImagen4Graph.ts @@ -1,28 +1,20 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { isImagenAspectRatioID } from 'features/controlLayers/store/types'; import { zModelIdentifierField } from 'features/nodes/types/common'; -import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { selectCanvasOutputFields, selectPresetModifiedPrompts } from 'features/nodes/util/graph/graphBuilderUtils'; -import { type GraphBuilderReturn, UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn } from 'features/nodes/util/graph/types'; +import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { t } from 'i18next'; -import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; const log = logger('system'); -export const buildImagen4Graph = async ( - state: RootState, - manager: CanvasManager | null -): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildImagen4Graph = (arg: GraphBuilderArg): GraphBuilderReturn => { + const { generationMode, state } = arg; if (generationMode !== 'txt2img') { throw new UnsupportedGenerationModeError(t('toast.imagenIncompatibleGenerationMode', { model: 'Imagen4' })); @@ -41,34 +33,30 @@ export const buildImagen4Graph = async ( assert(isImagenAspectRatioID(bbox.aspectRatio.id), 'Imagen4 does not support this aspect ratio'); assert(positivePrompt.length > 0, 'Imagen4 requires positive prompt to have at least one character'); - if (generationMode === 'txt2img') { - const g = new Graph(getPrefixedId('imagen4_txt2img_graph')); - const imagen4 = g.addNode({ - // @ts-expect-error: These nodes are not available in the OSS application - type: 'google_imagen4_generate_image', - model: zModelIdentifierField.parse(model), - positive_prompt: positivePrompt, - negative_prompt: negativePrompt, - aspect_ratio: bbox.aspectRatio.id, - // When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed. - enhance_prompt: true, - ...selectCanvasOutputFields(state), - }); - g.upsertMetadata({ - positive_prompt: positivePrompt, - negative_prompt: negativePrompt, - width: bbox.rect.width, - height: bbox.rect.height, - model: Graph.getModelMetadataField(model), - ...selectCanvasMetadata(state), - }); + const g = new Graph(getPrefixedId('imagen4_txt2img_graph')); + const imagen4 = g.addNode({ + // @ts-expect-error: These nodes are not available in the OSS application + type: 'google_imagen4_generate_image', + model: zModelIdentifierField.parse(model), + positive_prompt: positivePrompt, + negative_prompt: negativePrompt, + aspect_ratio: bbox.aspectRatio.id, + // When enhance_prompt is true, Imagen4 will return a new image every time, ignoring the seed. + enhance_prompt: true, + ...selectCanvasOutputFields(state), + }); + g.upsertMetadata({ + positive_prompt: positivePrompt, + negative_prompt: negativePrompt, + width: bbox.rect.width, + height: bbox.rect.height, + model: Graph.getModelMetadataField(model), + ...selectCanvasMetadata(state), + }); - return { - g, - seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' }, - positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' }, - }; - } - - assert>(false, 'Invalid generation mode for Imagen4'); + return { + g, + seedFieldIdentifier: { nodeId: imagen4.id, fieldName: 'seed' }, + positivePromptFieldIdentifier: { nodeId: imagen4.id, fieldName: 'positive_prompt' }, + }; }; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index af5c0d5601..6d5113f1b5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -1,6 +1,4 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; @@ -16,15 +14,13 @@ import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint'; import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; -import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { getSizes, selectCanvasOutputFields, selectPresetModifiedPrompts, } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; @@ -33,9 +29,8 @@ import { addRegions } from './addRegions'; const log = logger('system'); -export const buildSD1Graph = async (state: RootState, manager: CanvasManager | null): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildSD1Graph = async (arg: GraphBuilderArg): Promise => { + const { generationMode, state } = arg; log.debug({ generationMode }, 'Building SD1/SD2 graph'); const params = selectParamsSlice(state); @@ -171,10 +166,9 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize }); g.upsertMetadata({ generation_mode: 'txt2img' }); } else if (generationMode === 'img2img') { - assert(manager, 'Need manager to do img2img'); canvasOutput = await addImageToImage({ g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'i2l', denoise, @@ -187,11 +181,10 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n }); g.upsertMetadata({ generation_mode: 'img2img' }); } else if (generationMode === 'inpaint') { - assert(manager, 'Need manager to do inpaint'); canvasOutput = await addInpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'i2l', denoise, @@ -205,11 +198,10 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n }); g.upsertMetadata({ generation_mode: 'inpaint' }); } else if (generationMode === 'outpaint') { - assert(manager, 'Need manager to do outpaint'); canvasOutput = await addOutpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'i2l', denoise, @@ -226,13 +218,13 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n assert>(false); } - if (manager) { + if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') { const controlNetCollector = g.addNode({ type: 'collect', id: getPrefixedId('control_net_collector'), }); const controlNetResult = await addControlNets({ - manager, + manager: arg.canvasManager, entities: canvas.controlLayers.entities, g, rect: canvas.bbox.rect, @@ -250,7 +242,7 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n id: getPrefixedId('t2i_adapter_collector'), }); const t2iAdapterResult = await addT2IAdapters({ - manager, + manager: arg.canvasManager, entities: canvas.controlLayers.entities, g, rect: canvas.bbox.rect, @@ -276,9 +268,9 @@ export const buildSD1Graph = async (state: RootState, manager: CanvasManager | n }); let totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters; - if (manager) { + if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') { const regionsResult = await addRegions({ - manager, + manager: arg.canvasManager, regions: canvas.regionalGuidance.entities, g, bbox: canvas.bbox.rect, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD3Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD3Graph.ts index a1539f7c9f..c56e411fd7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD3Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD3Graph.ts @@ -1,6 +1,4 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; @@ -10,24 +8,21 @@ import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChec import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; -import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { getSizes, selectCanvasOutputFields, selectPresetModifiedPrompts, } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; const log = logger('system'); -export const buildSD3Graph = async (state: RootState, manager: CanvasManager | null): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildSD3Graph = async (arg: GraphBuilderArg): Promise => { + const { generationMode, state } = arg; log.debug({ generationMode }, 'Building SD3 graph'); const model = selectMainModelConfig(state); @@ -133,10 +128,9 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize }); g.upsertMetadata({ generation_mode: 'sd3_txt2img' }); } else if (generationMode === 'img2img') { - assert(manager, 'Need manager to do img2img'); canvasOutput = await addImageToImage({ g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'sd3_i2l', denoise, @@ -149,11 +143,10 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n }); g.upsertMetadata({ generation_mode: 'sd3_img2img' }); } else if (generationMode === 'inpaint') { - assert(manager, 'Need manager to do inpaint'); canvasOutput = await addInpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'sd3_i2l', denoise, @@ -167,11 +160,10 @@ export const buildSD3Graph = async (state: RootState, manager: CanvasManager | n }); g.upsertMetadata({ generation_mode: 'sd3_inpaint' }); } else if (generationMode === 'outpaint') { - assert(manager, 'Need manager to do outpaint'); canvasOutput = await addOutpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'sd3_i2l', denoise, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 13534f61e1..5b7b542ee0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -1,6 +1,4 @@ import { logger } from 'app/logging/logger'; -import type { RootState } from 'app/store/store'; -import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectMainModelConfig, selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectRefImagesSlice } from 'features/controlLayers/store/refImagesSlice'; @@ -16,15 +14,13 @@ import { addSDXLRefiner } from 'features/nodes/util/graph/generation/addSDXLRefi import { addSeamless } from 'features/nodes/util/graph/generation/addSeamless'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; -import { getGenerationMode } from 'features/nodes/util/graph/generation/getGenerationMode'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; import { getSizes, selectCanvasOutputFields, selectPresetModifiedPrompts, } from 'features/nodes/util/graph/graphBuilderUtils'; -import type { GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; -import { selectActiveTab } from 'features/ui/store/uiSelectors'; +import type { GraphBuilderArg, GraphBuilderReturn, ImageOutputNodes } from 'features/nodes/util/graph/types'; import type { Invocation } from 'services/api/types'; import type { Equals } from 'tsafe'; import { assert } from 'tsafe'; @@ -33,9 +29,8 @@ import { addRegions } from './addRegions'; const log = logger('system'); -export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | null): Promise => { - const tab = selectActiveTab(state); - const generationMode = await getGenerationMode(manager, tab); +export const buildSDXLGraph = async (arg: GraphBuilderArg): Promise => { + const { generationMode, state } = arg; log.debug({ generationMode }, 'Building SDXL graph'); const model = selectMainModelConfig(state); @@ -178,10 +173,9 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | canvasOutput = addTextToImage({ g, l2i, originalSize, scaledSize }); g.upsertMetadata({ generation_mode: 'sdxl_txt2img' }); } else if (generationMode === 'img2img') { - assert(manager, 'Need manager to do img2img'); canvasOutput = await addImageToImage({ g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'i2l', denoise, @@ -194,11 +188,10 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | }); g.upsertMetadata({ generation_mode: 'sdxl_img2img' }); } else if (generationMode === 'inpaint') { - assert(manager, 'Need manager to do inpaint'); canvasOutput = await addInpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'i2l', denoise, @@ -212,11 +205,10 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | }); g.upsertMetadata({ generation_mode: 'sdxl_inpaint' }); } else if (generationMode === 'outpaint') { - assert(manager, 'Need manager to do outpaint'); canvasOutput = await addOutpaint({ state, g, - manager, + manager: arg.canvasManager, l2i, i2lNodeType: 'i2l', denoise, @@ -233,13 +225,13 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | assert>(false); } - if (manager) { + if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') { const controlNetCollector = g.addNode({ type: 'collect', id: getPrefixedId('control_net_collector'), }); const controlNetResult = await addControlNets({ - manager, + manager: arg.canvasManager, entities: canvas.controlLayers.entities, g, rect: canvas.bbox.rect, @@ -257,7 +249,7 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | id: getPrefixedId('t2i_adapter_collector'), }); const t2iAdapterResult = await addT2IAdapters({ - manager, + manager: arg.canvasManager, entities: canvas.controlLayers.entities, g, rect: canvas.bbox.rect, @@ -283,9 +275,9 @@ export const buildSDXLGraph = async (state: RootState, manager: CanvasManager | }); let totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters; - if (manager) { + if (generationMode === 'img2img' || generationMode === 'inpaint' || generationMode === 'outpaint') { const regionsResult = await addRegions({ - manager, + manager: arg.canvasManager, regions: canvas.regionalGuidance.entities, g, bbox: canvas.bbox.rect, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts index 8e41a97608..db8612958f 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/graphBuilderUtils.ts @@ -1,6 +1,7 @@ import { createSelector } from '@reduxjs/toolkit'; import type { RootState } from 'app/store/store'; import { pick } from 'es-toolkit/compat'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import type { CanvasState, ParamsState } from 'features/controlLayers/store/types'; @@ -13,7 +14,7 @@ import { selectListStylePresetsRequestState } from 'services/api/endpoints/style import type { Invocation, S } from 'services/api/types'; import { assert } from 'tsafe'; -import type { MainModelLoaderNodes } from './types'; +import type { GraphBuilderArg, MainModelLoaderNodes } from './types'; /** * Gets the board field, based on the autoAddBoardId setting. @@ -165,3 +166,19 @@ export const isCanvasOutputNodeId = (nodeId: string) => nodeId.split(':')[0] === export const isCanvasOutputEvent = (data: S['InvocationCompleteEvent']) => { return isCanvasOutputNodeId(data.invocation_source_id); }; + +export const getGraphBuilderArg = async ( + state: RootState, + canvasManager: CanvasManager | null +): Promise => { + const tab = selectActiveTab(state); + if (tab === 'generate') { + return { generationMode: 'txt2img', state }; + } else if (tab === 'canvas') { + assert(canvasManager !== null, 'CanvasManager should not be null in canvas tab'); + const generationMode = await canvasManager.compositor.getGenerationMode(); + return { generationMode, state, canvasManager }; + } else { + assert(false, `Unknown tab: ${tab}`); + } +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/types.ts b/invokeai/frontend/web/src/features/nodes/util/graph/types.ts index 16c0c0341c..03af86f64e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/types.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/types.ts @@ -1,3 +1,6 @@ +import type { RootState } from 'app/store/store'; +import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; +import type { GenerationMode } from 'features/controlLayers/store/types'; import type { FieldIdentifier } from 'features/nodes/types/field'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; @@ -27,6 +30,17 @@ export type MainModelLoaderNodes = export type VaeSourceNodes = 'seamless' | 'vae_loader'; +export type GraphBuilderArg = + | { + generationMode: Extract; + state: RootState; + } + | { + generationMode: Exclude; + state: RootState; + canvasManager: CanvasManager; + }; + export type GraphBuilderReturn = { g: Graph; seedFieldIdentifier?: FieldIdentifier; diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts index 6032a52d46..c54ab3e137 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueCanvas.ts @@ -7,7 +7,8 @@ import { extractMessageFromAssertionError } from 'common/util/extractMessageFrom import { withResult, withResultAsync } from 'common/util/result'; import { useCanvasManagerSafe } from 'features/controlLayers/contexts/CanvasManagerProviderGate'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; -import { canvasSessionIdCreated, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; +import { canvasSessionIdChanged, selectCanvasSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph'; import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph'; @@ -18,6 +19,7 @@ import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildIma import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph'; import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; +import type { GraphBuilderArg } from 'features/nodes/util/graph/types'; import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; @@ -33,40 +35,42 @@ const enqueueCanvas = async (store: AppStore, canvasManager: CanvasManager, prep dispatch(enqueueRequestedCanvas()); - let destination = selectCanvasSessionId(getState()); - if (!destination) { - dispatch(canvasSessionIdCreated()); - destination = selectCanvasSessionId(getState()); - } - assert(destination !== null); - const state = getState(); - const model = state.params.model; - assert(model, 'No model found in state'); - const base = model.base; + let destination = selectCanvasSessionId(state); + if (destination === null) { + destination = getPrefixedId('canvas'); + dispatch(canvasSessionIdChanged({ id: destination })); + } const buildGraphResult = await withResultAsync(async () => { + const model = state.params.model; + assert(model, 'No model found in state'); + const base = model.base; + + const generationMode = await canvasManager.compositor.getGenerationMode(); + const graphBuilderArg: GraphBuilderArg = { generationMode, state, canvasManager }; + switch (base) { case 'sdxl': - return await buildSDXLGraph(state, canvasManager); + return await buildSDXLGraph(graphBuilderArg); case 'sd-1': case `sd-2`: - return await buildSD1Graph(state, canvasManager); + return await buildSD1Graph(graphBuilderArg); case `sd-3`: - return await buildSD3Graph(state, canvasManager); + return await buildSD3Graph(graphBuilderArg); case `flux`: - return await buildFLUXGraph(state, canvasManager); + return await buildFLUXGraph(graphBuilderArg); case 'cogview4': - return await buildCogView4Graph(state, canvasManager); + return await buildCogView4Graph(graphBuilderArg); case 'imagen3': - return await buildImagen3Graph(state, canvasManager); + return buildImagen3Graph(graphBuilderArg); case 'imagen4': - return await buildImagen4Graph(state, canvasManager); + return buildImagen4Graph(graphBuilderArg); case 'chatgpt-4o': - return await buildChatGPT4oGraph(state, canvasManager); + return await buildChatGPT4oGraph(graphBuilderArg); case 'flux-kontext': - return await buildFluxKontextGraph(state, canvasManager); + return buildFluxKontextGraph(graphBuilderArg); default: assert(false, `No graph builders for base ${base}`); } diff --git a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts index 9872861dc2..ab6172c828 100644 --- a/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts +++ b/invokeai/frontend/web/src/features/queue/hooks/useEnqueueGenerate.ts @@ -5,7 +5,8 @@ import type { AppStore } from 'app/store/store'; import { useAppStore } from 'app/store/storeHooks'; import { extractMessageFromAssertionError } from 'common/util/extractMessageFromAssertionError'; import { withResult, withResultAsync } from 'common/util/result'; -import { generateSessionIdCreated, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; +import { getPrefixedId } from 'features/controlLayers/konva/util'; +import { generateSessionIdChanged, selectGenerateSessionId } from 'features/controlLayers/store/canvasStagingAreaSlice'; import { prepareLinearUIBatch } from 'features/nodes/util/graph/buildLinearBatchConfig'; import { buildChatGPT4oGraph } from 'features/nodes/util/graph/generation/buildChatGPT4oGraph'; import { buildCogView4Graph } from 'features/nodes/util/graph/generation/buildCogView4Graph'; @@ -16,6 +17,7 @@ import { buildImagen4Graph } from 'features/nodes/util/graph/generation/buildIma import { buildSD1Graph } from 'features/nodes/util/graph/generation/buildSD1Graph'; import { buildSD3Graph } from 'features/nodes/util/graph/generation/buildSD3Graph'; import { buildSDXLGraph } from 'features/nodes/util/graph/generation/buildSDXLGraph'; +import type { GraphBuilderArg } from 'features/nodes/util/graph/types'; import { UnsupportedGenerationModeError } from 'features/nodes/util/graph/types'; import { toast } from 'features/toast/toast'; import { useCallback } from 'react'; @@ -32,40 +34,41 @@ const enqueueGenerate = async (store: AppStore, prepend: boolean) => { dispatch(enqueueRequestedGenerate()); - let destination = selectGenerateSessionId(getState()); - if (!destination) { - dispatch(generateSessionIdCreated()); - destination = selectGenerateSessionId(getState()); - } - assert(destination !== null); - const state = getState(); - const model = state.params.model; - assert(model, 'No model found in state'); - const base = model.base; + let destination = selectGenerateSessionId(state); + if (destination === null) { + destination = getPrefixedId('generate'); + dispatch(generateSessionIdChanged({ id: destination })); + } const buildGraphResult = await withResultAsync(async () => { + const model = state.params.model; + assert(model, 'No model found in state'); + const base = model.base; + + const graphBuilderArg: GraphBuilderArg = { generationMode: 'txt2img', state }; + switch (base) { case 'sdxl': - return await buildSDXLGraph(state, null); + return await buildSDXLGraph(graphBuilderArg); case 'sd-1': case `sd-2`: - return await buildSD1Graph(state, null); + return await buildSD1Graph(graphBuilderArg); case `sd-3`: - return await buildSD3Graph(state, null); + return await buildSD3Graph(graphBuilderArg); case `flux`: - return await buildFLUXGraph(state, null); + return await buildFLUXGraph(graphBuilderArg); case 'cogview4': - return await buildCogView4Graph(state, null); + return await buildCogView4Graph(graphBuilderArg); case 'imagen3': - return await buildImagen3Graph(state, null); + return buildImagen3Graph(graphBuilderArg); case 'imagen4': - return await buildImagen4Graph(state, null); + return buildImagen4Graph(graphBuilderArg); case 'chatgpt-4o': - return await buildChatGPT4oGraph(state, null); + return await buildChatGPT4oGraph(graphBuilderArg); case 'flux-kontext': - return await buildFluxKontextGraph(state, null); + return buildFluxKontextGraph(graphBuilderArg); default: assert(false, `No graph builders for base ${base}`); }