diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx index ed2bb86a88..c4462c0f4d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx @@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => { justifyContent="flex-start" leftIcon={} onClick={addRegionalGuidance} - isDisabled={isFLUX || isSD3} + isDisabled={isSD3} > {t('controlLayers.regionalGuidance')} diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx index 70623f54b7..40c750bc52 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx @@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => { } onClick={addInpaintMask}> {t('controlLayers.inpaintMask')} - } onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}> + } onClick={addRegionalGuidance} isDisabled={isSD3}> {t('controlLayers.regionalGuidance')} } onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}> diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index dcce2046da..2e0495e47c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -50,14 +50,15 @@ export const addRegions = async ( g: Graph, bbox: Rect, base: BaseModelType, - denoise: Invocation<'denoise_latents'>, - posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, - negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, + denoise: Invocation<'denoise_latents' | 'flux_denoise'>, + posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>, + negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null, posCondCollect: Invocation<'collect'>, - negCondCollect: Invocation<'collect'>, + negCondCollect: Invocation<'collect'> | null, ipAdapterCollect: Invocation<'collect'> ): Promise => { const isSDXL = base === 'sdxl'; + const isFLUX = base === 'flux'; const validRegions = regions.filter((rg) => isValidRegion(rg, base)); const results: AddedRegionResult[] = []; @@ -94,20 +95,27 @@ export const addRegions = async ( if (region.positivePrompt) { // The main positive conditioning node result.addedPositivePrompt = true; - const regionalPosCond = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_positive_cond'), - prompt: region.positivePrompt, - style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_positive_cond'), - prompt: region.positivePrompt, - } - ); + let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalPosCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? + }); + } else if (isFLUX) { + regionalPosCond = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + }); + } else { + regionalPosCond = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + }); + } // Connect the mask to the conditioning g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask'); // Connect the conditioning to the collector @@ -115,38 +123,55 @@ export const addRegions = async ( // Copy the connections to the "global" positive conditioning node to the regional cond if (posCond.type === 'compel') { for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } + } else if (posCond.type === 'sdxl_compel_prompt') { + for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } + } else if (posCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalPosCond.id; g.addEdgeFromObj(clone); } } else { - for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node - const clone = deepClone(edge); - clone.destination.node_id = regionalPosCond.id; - g.addEdgeFromObj(clone); - } + assert(false, 'Unsupported positive conditioning node type.'); } } if (region.negativePrompt) { - result.addedNegativePrompt = true; + assert(negCond, 'Negative conditioning node is required if there is a negative prompt'); + assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt'); + // The main negative conditioning node - const regionalNegCond = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_negative_cond'), - prompt: region.negativePrompt, - style: region.negativePrompt, - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_negative_cond'), - prompt: region.negativePrompt, - } - ); + result.addedNegativePrompt = true; + let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalNegCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + style: region.negativePrompt, + }); + } else if (isFLUX) { + regionalNegCond = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + }); + } else { + regionalNegCond = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + }); + } + // Connect the mask to the conditioning g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask'); // Connect the conditioning to the collector @@ -158,17 +183,27 @@ export const addRegions = async ( clone.destination.node_id = regionalNegCond.id; g.addEdgeFromObj(clone); } - } else { + } else if (negCond.type === 'sdxl_compel_prompt') { for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalNegCond.id; g.addEdgeFromObj(clone); } + } else if (negCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalNegCond.id; + g.addEdgeFromObj(clone); + } + } else { + assert(false, 'Unsupported negative conditioning node type.'); } } // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node if (region.autoNegative && region.positivePrompt) { + assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting'); + result.addedAutoNegativePositivePrompt = true; // We re-use the mask image, but invert it when converting to tensor const invertTensorMask = g.addNode({ @@ -178,20 +213,27 @@ export const addRegions = async ( // Connect the OG mask image to the inverted mask-to-tensor node g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask'); // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt - const regionalPosCondInverted = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_positive_cond_inverted'), - prompt: region.positivePrompt, - style: region.positivePrompt, - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_positive_cond_inverted'), - prompt: region.positivePrompt, - } - ); + let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalPosCondInverted = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + style: region.positivePrompt, + }); + } else if (isFLUX) { + regionalPosCondInverted = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + }); + } else { + regionalPosCondInverted = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + }); + } // Connect the inverted mask to the conditioning g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask'); // Connect the conditioning to the negative collector @@ -203,12 +245,20 @@ export const addRegions = async ( clone.destination.node_id = regionalPosCondInverted.id; g.addEdgeFromObj(clone); } - } else { + } else if (posCond.type === 'sdxl_compel_prompt') { for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalPosCondInverted.id; g.addEdgeFromObj(clone); } + } else if (posCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCondInverted.id; + g.addEdgeFromObj(clone); + } + } else { + assert(false, 'Unsupported positive conditioning node type.'); } } @@ -217,6 +267,8 @@ export const addRegions = async ( ); for (const { id, ipAdapter } of validRGIPAdapters) { + assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.'); + result.addedIPAdapters++; const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter; assert(model, 'IP Adapter model is required'); @@ -250,7 +302,7 @@ export const addRegions = async ( }; const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => { - // Must be have a model that matches the current base and must have a control image + // Must be a model that matches the current base and must have a control image const hasModel = Boolean(ipAdapter.model); const modelMatchesBase = ipAdapter.model?.base === base; const hasImage = Boolean(ipAdapter.image); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index d893760f3c..22be9f58e7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint'; +import { addRegions } from 'features/nodes/util/graph/generation/addRegions'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; @@ -79,7 +80,10 @@ export const buildFLUXGraph = async ( id: getPrefixedId('flux_text_encoder'), prompt: positivePrompt, }); - + const posCondCollect = g.addNode({ + type: 'collect', + id: getPrefixedId('pos_cond_collect'), + }); const denoise = g.addNode({ type: 'flux_denoise', id: getPrefixedId('flux_denoise'), @@ -104,13 +108,12 @@ export const buildFLUXGraph = async ( g.addEdge(modelLoader, 'clip', posCond, 'clip'); g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder'); g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len'); + g.addEdge(posCond, 'conditioning', posCondCollect, 'item'); + g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning'); + g.addEdge(denoise, 'latents', l2i, 'latents'); addFLUXLoRAs(state, g, denoise, modelLoader, posCond); - g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning'); - - g.addEdge(denoise, 'latents', l2i, 'latents'); - const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); assert(modelConfig.base === 'flux'); @@ -216,7 +219,22 @@ export const buildFLUXGraph = async ( }); const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); - const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters; + const regionsResult = await addRegions( + manager, + canvas.regionalGuidance.entities, + g, + canvas.bbox.rect, + modelConfig.base, + denoise, + posCond, + null, + posCondCollect, + null, + ipAdapterCollector + ); + + const totalIPAdaptersAdded = + ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); } else { diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 90adf7517e..8878a4bcfa 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -6564,6 +6564,11 @@ export type components = { * @description The name of conditioning tensor */ conditioning_name: string; + /** + * @description The mask associated with this conditioning tensor. Excluded regions should be set to False, included regions should be set to True. + * @default null + */ + mask?: components["schemas"]["TensorField"] | null; }; /** * FluxConditioningOutput @@ -6771,15 +6776,17 @@ export type components = { */ transformer?: components["schemas"]["TransformerField"]; /** + * Positive Text Conditioning * @description Positive conditioning tensor * @default null */ - positive_text_conditioning?: components["schemas"]["FluxConditioningField"]; + positive_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][]; /** + * Negative Text Conditioning * @description Negative conditioning tensor. Can be None if cfg_scale is 1.0. * @default null */ - negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | null; + negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][] | null; /** * CFG Scale * @description Classifier-Free Guidance scale @@ -7133,6 +7140,11 @@ export type components = { * @default null */ prompt?: string; + /** + * @description A mask defining the region that this conditioning prompt applies to. + * @default null + */ + mask?: components["schemas"]["TensorField"] | null; /** * type * @default flux_text_encoder