diff --git a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx index 03fb57411a..400a0f845d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/ControlLayer/ControlLayerControlAdapter.tsx @@ -16,13 +16,13 @@ import { controlLayerModelChanged, controlLayerWeightChanged, } from 'features/controlLayers/store/canvasSlice'; +import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors'; import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types'; import { memo, useCallback, useMemo } from 'react'; import { useTranslation } from 'react-i18next'; import { PiBoundingBoxBold, PiShootingStarBold, PiUploadBold } from 'react-icons/pi'; import type { ControlNetModelConfig, PostUploadAction, T2IAdapterModelConfig } from 'services/api/types'; -import { selectIsFLUX } from '../../store/paramsSlice'; const useControlLayerControlAdapter = (entityIdentifier: CanvasEntityIdentifier<'control_layer'>) => { const selectControlAdapter = useMemo( diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index 347ca4fba4..908f912358 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -110,16 +110,21 @@ const addControlNetToGraph = ( const controlNet = g.addNode({ id: `control_net_${id}`, - type: 'controlnet', + type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet', begin_step_percent: beginEndStepPct[0], end_step_percent: beginEndStepPct[1], - control_mode: controlMode, + control_mode: model.base === 'flux' ? undefined : controlMode, resize_mode: 'just_resize', control_model: model, control_weight: weight, image: { image_name }, }); - g.addEdge(controlNet, 'control', collector, 'item'); + + if (controlNet.type === 'flux_controlnet') { + g.addEdge(controlNet, 'controlnet', collector, 'item'); + } else { + g.addEdge(controlNet, 'control', collector, 'item'); + } }; const addT2IAdapterToGraph = ( diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 50e55526b0..d58f38b8c4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -19,6 +19,8 @@ import type { Invocation } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; +import { addControlNets } from './addControlAdapters'; + const log = logger('system'); export const buildFLUXGraph = async ( @@ -177,6 +179,24 @@ export const buildFLUXGraph = async ( ); } + const controlNetCollector = g.addNode({ + type: 'collect', + id: getPrefixedId('control_net_collector'), + }); + const controlNetResult = await addControlNets( + manager, + canvas.controlLayers.entities, + g, + canvas.bbox.rect, + controlNetCollector, + modelConfig.base + ); + if (controlNetResult.addedControlNets > 0) { + g.addEdge(controlNetCollector, 'collection', noise, 'controlnet'); + } else { + g.deleteNode(controlNetCollector.id); + } + if (state.system.shouldUseNSFWChecker) { canvasOutput = addNSFWChecker(g, canvasOutput); }