diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 5e76931933..16bfcaeed7 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel):
"""A conditioning tensor primitive value"""
conditioning_name: str = Field(description="The name of conditioning tensor")
+ mask: Optional[TensorField] = Field(
+ default=None,
+ description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
+ "included regions should be set to True.",
+ )
class SD3ConditioningField(BaseModel):
diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 9e197626b5..f2a2dd586a 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -30,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
@@ -42,6 +43,7 @@ from invokeai.backend.flux.sampling_utils import (
pack,
unpack,
)
+from invokeai.backend.flux.text_conditioning import FluxTextConditioning
from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
from invokeai.backend.lora.lora_patcher import LoRAPatcher
@@ -87,10 +89,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
input=Input.Connection,
title="Transformer",
)
- positive_text_conditioning: FluxConditioningField = InputField(
+ positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection
)
- negative_text_conditioning: FluxConditioningField | None = InputField(
+ negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
default=None,
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
@@ -139,36 +141,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
- def _load_text_conditioning(
- self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- # Load the conditioning data.
- cond_data = context.conditioning.load(conditioning_name)
- assert len(cond_data.conditionings) == 1
- flux_conditioning = cond_data.conditionings[0]
- assert isinstance(flux_conditioning, FLUXConditioningInfo)
- flux_conditioning = flux_conditioning.to(dtype=dtype)
- t5_embeddings = flux_conditioning.t5_embeds
- clip_embeddings = flux_conditioning.clip_embeds
- return t5_embeddings, clip_embeddings
-
def _run_diffusion(
self,
context: InvocationContext,
):
inference_dtype = torch.bfloat16
- # Load the conditioning data.
- pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
- context, self.positive_text_conditioning.conditioning_name, inference_dtype
- )
- neg_t5_embeddings: torch.Tensor | None = None
- neg_clip_embeddings: torch.Tensor | None = None
- if self.negative_text_conditioning is not None:
- neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
- context, self.negative_text_conditioning.conditioning_name, inference_dtype
- )
-
# Load the input latents, if provided.
init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
if init_latents is not None:
@@ -183,15 +161,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
dtype=inference_dtype,
seed=self.seed,
)
+ b, _c, latent_h, latent_w = noise.shape
+ packed_h = latent_h // 2
+ packed_w = latent_w // 2
+
+ # Load the conditioning data.
+ pos_text_conditionings = self._load_text_conditioning(
+ context=context,
+ cond_field=self.positive_text_conditioning,
+ packed_height=packed_h,
+ packed_width=packed_w,
+ dtype=inference_dtype,
+ device=TorchDevice.choose_torch_device(),
+ )
+ neg_text_conditionings: list[FluxTextConditioning] | None = None
+ if self.negative_text_conditioning is not None:
+ neg_text_conditionings = self._load_text_conditioning(
+ context=context,
+ cond_field=self.negative_text_conditioning,
+ packed_height=packed_h,
+ packed_width=packed_w,
+ dtype=inference_dtype,
+ device=TorchDevice.choose_torch_device(),
+ )
+ pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(
+ pos_text_conditionings, img_seq_len=packed_h * packed_w
+ )
+ neg_regional_prompting_extension = (
+ RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w)
+ if neg_text_conditionings
+ else None
+ )
transformer_info = context.models.load(self.transformer.transformer)
is_schnell = "schnell" in transformer_info.config.config_path
# Calculate the timestep schedule.
- image_seq_len = noise.shape[-1] * noise.shape[-2] // 4
timesteps = get_schedule(
num_steps=self.num_steps,
- image_seq_len=image_seq_len,
+ image_seq_len=packed_h * packed_w,
shift=not is_schnell,
)
@@ -228,28 +236,17 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inpaint_mask = self._prep_inpaint_mask(context, x)
- b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
- pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
- pos_txt_ids = torch.zeros(
- pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
- )
- neg_txt_ids: torch.Tensor | None = None
- if neg_t5_embeddings is not None:
- neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
- neg_txt_ids = torch.zeros(
- neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
- )
-
# Pack all latent tensors.
init_latents = pack(init_latents) if init_latents is not None else None
inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
noise = pack(noise)
x = pack(x)
- # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly.
- assert image_seq_len == x.shape[1]
+ # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and
+ # packed_w correctly.
+ assert packed_h * packed_w == x.shape[1]
# Prepare inpaint extension.
inpaint_extension: InpaintExtension | None = None
@@ -338,12 +335,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
model=transformer,
img=x,
img_ids=img_ids,
- txt=pos_t5_embeddings,
- txt_ids=pos_txt_ids,
- vec=pos_clip_embeddings,
- neg_txt=neg_t5_embeddings,
- neg_txt_ids=neg_txt_ids,
- neg_vec=neg_clip_embeddings,
+ pos_regional_prompting_extension=pos_regional_prompting_extension,
+ neg_regional_prompting_extension=neg_regional_prompting_extension,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
@@ -357,6 +350,43 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
x = unpack(x.float(), self.height, self.width)
return x
+ def _load_text_conditioning(
+ self,
+ context: InvocationContext,
+ cond_field: FluxConditioningField | list[FluxConditioningField],
+ packed_height: int,
+ packed_width: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> list[FluxTextConditioning]:
+ """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
+ # Normalize to a list of FluxConditioningFields.
+ cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
+
+ text_conditionings: list[FluxTextConditioning] = []
+ for cond_field in cond_list:
+ # Load the text embeddings.
+ cond_data = context.conditioning.load(cond_field.conditioning_name)
+ assert len(cond_data.conditionings) == 1
+ flux_conditioning = cond_data.conditionings[0]
+ assert isinstance(flux_conditioning, FLUXConditioningInfo)
+ flux_conditioning = flux_conditioning.to(dtype=dtype, device=device)
+ t5_embeddings = flux_conditioning.t5_embeds
+ clip_embeddings = flux_conditioning.clip_embeds
+
+ # Load the mask, if provided.
+ mask: Optional[torch.Tensor] = None
+ if cond_field.mask is not None:
+ mask = context.tensors.load(cond_field.mask.tensor_name)
+ mask = mask.to(device=device)
+ mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(
+ mask, packed_height, packed_width, dtype, device
+ )
+
+ text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
+
+ return text_conditionings
+
@classmethod
def prep_cfg_scale(
cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index 91c89cb31b..1eb0fea62e 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -1,11 +1,18 @@
from contextlib import ExitStack
-from typing import Iterator, Literal, Tuple
+from typing import Iterator, Literal, Optional, Tuple
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent
+from invokeai.app.invocations.fields import (
+ FieldDescriptions,
+ FluxConditioningField,
+ Input,
+ InputField,
+ TensorField,
+ UIComponent,
+)
from invokeai.app.invocations.model import CLIPField, T5EncoderField
from invokeai.app.invocations.primitives import FluxConditioningOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -41,9 +48,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
t5_max_seq_len: Literal[256, 512] = InputField(
description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
)
- prompt: str = InputField(
- description="Text prompt to encode.",
- ui_component=UIComponent.Textarea,
+ prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea)
+ mask: Optional[TensorField] = InputField(
+ default=None, description="A mask defining the region that this conditioning prompt applies to."
)
@torch.no_grad()
@@ -57,7 +64,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
)
conditioning_name = context.conditioning.save(conditioning_data)
- return FluxConditioningOutput.build(conditioning_name)
+ return FluxConditioningOutput(
+ conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+ )
def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py
index e0c7779e93..0f56adacde 100644
--- a/invokeai/backend/flux/custom_block_processor.py
+++ b/invokeai/backend/flux/custom_block_processor.py
@@ -1,9 +1,10 @@
import einops
import torch
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.math import attention
-from invokeai.backend.flux.modules.layers import DoubleStreamBlock
+from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock
class CustomDoubleStreamBlockProcessor:
@@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor:
@staticmethod
def _double_stream_block_forward(
- block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor
+ block: DoubleStreamBlock,
+ img: torch.Tensor,
+ txt: torch.Tensor,
+ vec: torch.Tensor,
+ pe: torch.Tensor,
+ attn_mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate
values.
@@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor:
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
- attn = attention(q, k, v, pe=pe)
+ attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
@@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor:
vec: torch.Tensor,
pe: torch.Tensor,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
+ regional_prompting_extension: RegionalPromptingExtension,
) -> tuple[torch.Tensor, torch.Tensor]:
"""A custom implementation of DoubleStreamBlock.forward() with additional features:
- IP-Adapter support
"""
- img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe)
+ attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index)
+ img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(
+ block, img, txt, vec, pe, attn_mask=attn_mask
+ )
# Apply IP-Adapter conditioning.
for ip_adapter_extension in ip_adapter_extensions:
@@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor:
)
return img, txt
+
+
+class CustomSingleStreamBlockProcessor:
+ """A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking,
+ etc.)
+ """
+
+ @staticmethod
+ def _single_stream_block_forward(
+ block: SingleStreamBlock,
+ x: torch.Tensor,
+ vec: torch.Tensor,
+ pe: torch.Tensor,
+ attn_mask: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """This function is a copy of SingleStreamBlock.forward() with an optional attn_mask forwarded to attention()."""
+ mod, _ = block.modulation(vec)
+ x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift
+ qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1)
+
+ q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads)
+ q, k = block.norm(q, k, v)
+
+ # compute attention
+ attn = attention(q, k, v, pe=pe, attn_mask=attn_mask)
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2))
+ return x + mod.gate * output
+
+ @staticmethod
+ def custom_single_block_forward(
+ timestep_index: int,
+ total_num_timesteps: int,
+ block_index: int,
+ block: SingleStreamBlock,
+ img: torch.Tensor,
+ vec: torch.Tensor,
+ pe: torch.Tensor,
+ regional_prompting_extension: RegionalPromptingExtension,
+ ) -> torch.Tensor:
+ """A custom implementation of SingleStreamBlock.forward() with additional features:
+ - Masking
+ """
+ attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index)
+ return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask)
diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index bb0e60409a..66e3984edb 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -7,6 +7,7 @@ from tqdm import tqdm
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.model import Flux
@@ -18,14 +19,8 @@ def denoise(
# model input
img: torch.Tensor,
img_ids: torch.Tensor,
- # positive text conditioning
- txt: torch.Tensor,
- txt_ids: torch.Tensor,
- vec: torch.Tensor,
- # negative text conditioning
- neg_txt: torch.Tensor | None,
- neg_txt_ids: torch.Tensor | None,
- neg_vec: torch.Tensor | None,
+ pos_regional_prompting_extension: RegionalPromptingExtension,
+ neg_regional_prompting_extension: RegionalPromptingExtension | None,
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
@@ -61,9 +56,9 @@ def denoise(
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
- txt=txt,
- txt_ids=txt_ids,
- y=vec,
+ txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+ txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+ y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
)
@@ -78,9 +73,9 @@ def denoise(
pred = model(
img=img,
img_ids=img_ids,
- txt=txt,
- txt_ids=txt_ids,
- y=vec,
+ txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+ txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+ y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
@@ -88,6 +83,7 @@ def denoise(
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
ip_adapter_extensions=pos_ip_adapter_extensions,
+ regional_prompting_extension=pos_regional_prompting_extension,
)
step_cfg_scale = cfg_scale[step_index]
@@ -97,15 +93,15 @@ def denoise(
# TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
# on systems with sufficient VRAM.
- if neg_txt is None or neg_txt_ids is None or neg_vec is None:
+ if neg_regional_prompting_extension is None:
raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
neg_pred = model(
img=img,
img_ids=img_ids,
- txt=neg_txt,
- txt_ids=neg_txt_ids,
- y=neg_vec,
+ txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+ txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+ y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
timesteps=t_vec,
guidance=guidance_vec,
timestep_index=step_index,
@@ -113,6 +109,7 @@ def denoise(
controlnet_double_block_residuals=None,
controlnet_single_block_residuals=None,
ip_adapter_extensions=neg_ip_adapter_extensions,
+ regional_prompting_extension=neg_regional_prompting_extension,
)
pred = neg_pred + step_cfg_scale * (pred - neg_pred)
diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py
new file mode 100644
index 0000000000..f1086c3286
--- /dev/null
+++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py
@@ -0,0 +1,276 @@
+from typing import Optional
+
+import torch
+import torchvision
+
+from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
+
+
+class RegionalPromptingExtension:
+ """A class for managing regional prompting with FLUX.
+
+ This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences).
+ """
+
+ def __init__(
+ self,
+ regional_text_conditioning: FluxRegionalTextConditioning,
+ restricted_attn_mask: torch.Tensor | None = None,
+ ):
+ self.regional_text_conditioning = regional_text_conditioning
+ self.restricted_attn_mask = restricted_attn_mask
+
+ def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
+ order = [self.restricted_attn_mask, None]
+ return order[block_index % len(order)]
+
+ def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None:
+ order = [self.restricted_attn_mask, None]
+ return order[block_index % len(order)]
+
+ @classmethod
+ def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int):
+ """Create a RegionalPromptingExtension from a list of text conditionings.
+
+ Args:
+ text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting.
+ img_seq_len (int): The image sequence length (i.e. packed_height * packed_width).
+ """
+ regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning)
+ attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask(
+ regional_text_conditioning, img_seq_len
+ )
+ return cls(
+ regional_text_conditioning=regional_text_conditioning,
+ restricted_attn_mask=attn_mask_with_restricted_img_self_attn,
+ )
+
+ # Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy:
+ #
+ # @classmethod
+ # def _prepare_unrestricted_attn_mask(
+ # cls,
+ # regional_text_conditioning: FluxRegionalTextConditioning,
+ # img_seq_len: int,
+ # ) -> torch.Tensor:
+ # """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that:
+ # - img self-attention is not masked.
+ # - img regions attend to both txt within their own region and to global prompts.
+ # """
+ # device = TorchDevice.choose_torch_device()
+
+ # # Infer txt_seq_len from the t5_embeddings tensor.
+ # txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
+
+ # # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
+ # # Concatenation happens in the following order: [txt_seq, img_seq].
+ # # There are 4 portions of the attention mask to consider as we prepare it:
+ # # 1. txt attends to itself
+ # # 2. txt attends to corresponding regional img
+ # # 3. regional img attends to corresponding txt
+ # # 4. regional img attends to itself
+
+ # # Initialize empty attention mask.
+ # regional_attention_mask = torch.zeros(
+ # (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
+ # )
+
+ # for image_mask, t5_embedding_range in zip(
+ # regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
+ # ):
+ # # 1. txt attends to itself
+ # regional_attention_mask[
+ # t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
+ # ] = 1.0
+
+ # # 2. txt attends to corresponding regional img
+ # # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
+ # fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0
+ # regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value
+
+ # # 3. regional img attends to corresponding txt
+ # # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
+ # fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0
+ # regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value
+
+ # # 4. regional img attends to itself
+ # # Allow unrestricted img self attention.
+ # regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0
+
+ # # Convert attention mask to boolean.
+ # regional_attention_mask = regional_attention_mask > 0.5
+
+ # return regional_attention_mask
+
+ @classmethod
+ def _prepare_restricted_attn_mask(
+ cls,
+ regional_text_conditioning: FluxRegionalTextConditioning,
+ img_seq_len: int,
+ ) -> torch.Tensor | None:
+ """Prepare a 'restricted' attention mask. In this context, 'restricted' means that:
+ - img self-attention is only allowed within regions.
+ - img regions only attend to txt within their own region, not to global prompts.
+ """
+ # Identify background region. I.e. the region that is not covered by any region masks.
+ background_region_mask: None | torch.Tensor = None
+ for image_mask in regional_text_conditioning.image_masks:
+ if image_mask is not None:
+ if background_region_mask is None:
+ background_region_mask = torch.ones_like(image_mask)
+ background_region_mask *= 1 - image_mask
+
+ if background_region_mask is None:
+ # There are no region masks, short-circuit and return None.
+ # TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this
+ # is a rare use case and would make the logic here significantly more complicated.
+ return None
+
+ device = TorchDevice.choose_torch_device()
+
+ # Infer txt_seq_len from the t5_embeddings tensor.
+ txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1]
+
+ # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied.
+ # Concatenation happens in the following order: [txt_seq, img_seq].
+ # There are 4 portions of the attention mask to consider as we prepare it:
+ # 1. txt attends to itself
+ # 2. txt attends to corresponding regional img
+ # 3. regional img attends to corresponding txt
+ # 4. regional img attends to itself
+
+ # Initialize empty attention mask.
+ regional_attention_mask = torch.zeros(
+ (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16
+ )
+
+ for image_mask, t5_embedding_range in zip(
+ regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True
+ ):
+ # 1. txt attends to itself
+ regional_attention_mask[
+ t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end
+ ] = 1.0
+
+ if image_mask is not None:
+ # 2. txt attends to corresponding regional img
+ # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
+ regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
+ image_mask.view(1, img_seq_len)
+ )
+
+ # 3. regional img attends to corresponding txt
+ # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
+ regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
+ image_mask.view(img_seq_len, 1)
+ )
+
+ # 4. regional img attends to itself
+ image_mask = image_mask.view(img_seq_len, 1)
+ regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T
+ else:
+ # We don't allow attention between non-background image regions and global prompts. This helps to ensure
+ # that regions focus on their local prompts. We do, however, allow attention between background regions
+ # and global prompts. If we didn't do this, then the background regions would not attend to any txt
+ # embeddings, which we found experimentally to cause artifacts.
+
+ # 2. global txt attends to background region
+ # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired.
+ regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = (
+ background_region_mask.view(1, img_seq_len)
+ )
+
+ # 3. background region attends to global txt
+ # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired.
+ regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = (
+ background_region_mask.view(img_seq_len, 1)
+ )
+
+ # Allow background img tokens to attend to all img tokens, and all img tokens to attend to background img tokens.
+ regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1)
+ regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len)
+
+ # Convert attention mask to boolean.
+ regional_attention_mask = regional_attention_mask > 0.5
+
+ return regional_attention_mask
+
+ @classmethod
+ def _concat_regional_text_conditioning(
+ cls,
+ text_conditionings: list[FluxTextConditioning],
+ ) -> FluxRegionalTextConditioning:
+ """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
+ concat_t5_embeddings: list[torch.Tensor] = []
+ concat_t5_embedding_ranges: list[Range] = []
+ image_masks: list[torch.Tensor | None] = []
+
+ # Choose global CLIP embedding.
+ # Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use
+ # the first prompt's CLIP embedding.
+ global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings
+ for text_conditioning in text_conditionings:
+ if text_conditioning.mask is None:
+ global_clip_embedding = text_conditioning.clip_embeddings
+ break
+
+ cur_t5_embedding_len = 0
+ for text_conditioning in text_conditionings:
+ concat_t5_embeddings.append(text_conditioning.t5_embeddings)
+
+ concat_t5_embedding_ranges.append(
+ Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
+ )
+
+ image_masks.append(text_conditioning.mask)
+
+ cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
+
+ t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
+
+ # Initialize the txt_ids tensor.
+ pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
+ t5_txt_ids = torch.zeros(
+ pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
+ )
+
+ return FluxRegionalTextConditioning(
+ t5_embeddings=t5_embeddings,
+ clip_embeddings=global_clip_embedding,
+ t5_txt_ids=t5_txt_ids,
+ image_masks=image_masks,
+ t5_embedding_ranges=concat_t5_embedding_ranges,
+ )
+
+ @staticmethod
+ def preprocess_regional_prompt_mask(
+ mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device
+ ) -> torch.Tensor:
+ """Preprocess a regional prompt mask to match the target height and width.
+ If mask is None, returns a mask of all ones with the target height and width.
+ If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
+
+ packed_height and packed_width are the target height and width of the mask in the 'packed' latent space.
+
+ Returns:
+ torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width).
+ """
+
+ if mask is None:
+ return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device)
+
+ mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+ tf = torchvision.transforms.Resize(
+ (packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+ )
+
+ # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
+ mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
+ resized_mask = tf(mask)
+
+ # Flatten the height and width dimensions into a single image_seq_len dimension.
+ return resized_mask.flatten(start_dim=2)
diff --git a/invokeai/backend/flux/math.py b/invokeai/backend/flux/math.py
index 0fac7a1d16..260243a35d 100644
--- a/invokeai/backend/flux/math.py
+++ b/invokeai/backend/flux/math.py
@@ -5,10 +5,10 @@ from einops import rearrange
from torch import Tensor
-def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
+def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor:
q, k = apply_rope(q, k, pe)
- x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "B H L D -> B L (H D)")
return x
diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py
index 0dadacd8fe..0add6fd4d7 100644
--- a/invokeai/backend/flux/model.py
+++ b/invokeai/backend/flux/model.py
@@ -5,7 +5,11 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
-from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
+from invokeai.backend.flux.custom_block_processor import (
+ CustomDoubleStreamBlockProcessor,
+ CustomSingleStreamBlockProcessor,
+)
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
@@ -95,6 +99,7 @@ class Flux(nn.Module):
controlnet_double_block_residuals: list[Tensor] | None,
controlnet_single_block_residuals: list[Tensor] | None,
ip_adapter_extensions: list[XLabsIPAdapterExtension],
+ regional_prompting_extension: RegionalPromptingExtension,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -117,7 +122,6 @@ class Flux(nn.Module):
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
for block_index, block in enumerate(self.double_blocks):
assert isinstance(block, DoubleStreamBlock)
-
img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward(
timestep_index=timestep_index,
total_num_timesteps=total_num_timesteps,
@@ -128,6 +132,7 @@ class Flux(nn.Module):
vec=vec,
pe=pe,
ip_adapter_extensions=ip_adapter_extensions,
+ regional_prompting_extension=regional_prompting_extension,
)
if controlnet_double_block_residuals is not None:
@@ -140,7 +145,17 @@ class Flux(nn.Module):
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
for block_index, block in enumerate(self.single_blocks):
- img = block(img, vec=vec, pe=pe)
+ assert isinstance(block, SingleStreamBlock)
+ img = CustomSingleStreamBlockProcessor.custom_single_block_forward(
+ timestep_index=timestep_index,
+ total_num_timesteps=total_num_timesteps,
+ block_index=block_index,
+ block=block,
+ img=img,
+ vec=vec,
+ pe=pe,
+ regional_prompting_extension=regional_prompting_extension,
+ )
if controlnet_single_block_residuals is not None:
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
diff --git a/invokeai/backend/flux/text_conditioning.py b/invokeai/backend/flux/text_conditioning.py
new file mode 100644
index 0000000000..5bc8b4d041
--- /dev/null
+++ b/invokeai/backend/flux/text_conditioning.py
@@ -0,0 +1,36 @@
+from dataclasses import dataclass
+
+import torch
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
+
+
+@dataclass
+class FluxTextConditioning:
+ """Text conditioning for one prompt: its T5 and CLIP embeddings plus an optional mask."""
+
+ # NOTE(review): tensor shapes are not stated here - confirm against the code that builds these.
+ t5_embeddings: torch.Tensor
+ clip_embeddings: torch.Tensor
+ # If mask is None, the prompt is a global prompt.
+ mask: torch.Tensor | None
+
+
+@dataclass
+class FluxRegionalTextConditioning:
+ """Text conditioning for several prompts: concatenated T5 embeddings, and for each prompt an image mask and the
+ range of T5 embeddings that belongs to it."""
+
+ # Concatenated text embeddings.
+ # Shape: (1, concatenated_txt_seq_len, 4096)
+ t5_embeddings: torch.Tensor
+ # Shape: (1, concatenated_txt_seq_len, 3)
+ t5_txt_ids: torch.Tensor
+
+ # Global CLIP embeddings.
+ # Shape: (1, 768)
+ clip_embeddings: torch.Tensor
+
+ # One binary mask per prompt, indicating the regions of the image that the prompt should be applied to.
+ # image_masks[i] is the mask for the ith prompt. If image_masks[i] is None, the ith prompt is a global prompt.
+ # A non-None image_masks[i] has shape (1, image_seq_len) and dtype torch.bool.
+ image_masks: list[torch.Tensor | None]
+
+ # List of ranges that represent the embedding ranges for each mask.
+ # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
+ t5_embedding_ranges: list[Range]
diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json
index d8957c0b1e..3b04bcff71 100644
--- a/invokeai/frontend/web/public/locales/en.json
+++ b/invokeai/frontend/web/public/locales/en.json
@@ -176,7 +176,8 @@
"reset": "Reset",
"none": "None",
"new": "New",
- "generating": "Generating"
+ "generating": "Generating",
+ "warnings": "Warnings"
},
"hrf": {
"hrf": "High Resolution Fix",
@@ -1040,6 +1041,7 @@
"noNodesInGraph": "No nodes in graph",
"systemDisconnected": "System disconnected",
"layer": {
+ "unsupportedModel": "layer not supported for selected base model",
"controlAdapterNoModelSelected": "no Control Adapter model selected",
"controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model",
"t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox width is {{width}}",
@@ -1050,7 +1052,10 @@
"ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model",
"ipAdapterNoImageSelected": "no IP Adapter image selected",
"rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters",
- "rgNoRegion": "no region selected"
+ "rgNegativePromptNotSupported": "negative prompt not supported for selected base model",
+ "rgReferenceImagesNotSupported": "regional reference images not supported for selected base model",
+ "rgAutoNegativeNotSupported": "auto-negative not supported for selected base model",
+ "emptyLayer": "empty layer"
}
},
"maskBlur": "Mask Blur",
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx
index ed2bb86a88..c4462c0f4d 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx
@@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={}
onClick={addRegionalGuidance}
- isDisabled={isFLUX || isSD3}
+ isDisabled={isSD3}
>
{t('controlLayers.regionalGuidance')}
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx
index 70623f54b7..40c750bc52 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx
@@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
} onClick={addInpaintMask}>
{t('controlLayers.inpaintMask')}
- } onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}>
+ } onClick={addRegionalGuidance} isDisabled={isSD3}>
{t('controlLayers.regionalGuidance')}
} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx
index 2fc7483756..3e11bdc4e2 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx
@@ -1,27 +1,28 @@
-import { IconButton, Tooltip } from '@invoke-ai/ui-library';
+import type { IconButtonProps } from '@invoke-ai/ui-library';
+import { IconButton } from '@invoke-ai/ui-library';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
-import { PiTrashSimpleFill } from 'react-icons/pi';
+import { PiXBold } from 'react-icons/pi';
-type Props = {
+type Props = Omit & {
onDelete: () => void;
};
-export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) => {
+export const RegionalGuidanceDeletePromptButton = memo(({ onDelete, ...rest }: Props) => {
const { t } = useTranslation();
return (
-
- }
- onClick={onDelete}
- flexGrow={0}
- size="sm"
- p={0}
- colorScheme="error"
- />
-
+ }
+ onClick={onDelete}
+ flexGrow={0}
+ size="sm"
+ p={0}
+ colorScheme="error"
+ {...rest}
+ />
);
});
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx
index c98d6bc23d..722720dfbf 100644
--- a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx
+++ b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx
@@ -1,8 +1,10 @@
import { Button, Flex, Text } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { useImageUploadButton } from 'common/hooks/useImageUploadButton';
+import { RegionalGuidanceDeletePromptButton } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton';
import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
+import { rgIPAdapterDeleted } from 'features/controlLayers/store/canvasSlice';
import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd';
import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd';
import { DndDropTarget } from 'features/dnd/DndDropTarget';
@@ -31,6 +33,9 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
const onClickGalleryButton = useCallback(() => {
dispatch(activeTabCanvasRightPanelChanged('gallery'));
}, [dispatch]);
+ const onDeleteIPAdapter = useCallback(() => {
+ dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId }));
+ }, [dispatch, entityIdentifier, referenceImageId]);
const dndTargetData = useMemo(
() =>
@@ -43,6 +48,7 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag
return (
+
{
return (
+
{entityIdentifier.type !== 'reference_image' && }
diff --git a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx
new file mode 100644
index 0000000000..f13de15c19
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx
@@ -0,0 +1,101 @@
+import { Flex, IconButton, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library';
+import { createSelector } from '@reduxjs/toolkit';
+import { EMPTY_ARRAY } from 'app/store/constants';
+import { useAppSelector } from 'app/store/storeHooks';
+import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext';
+import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled';
+import { selectModel } from 'features/controlLayers/store/paramsSlice';
+import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
+import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types';
+import {
+ getControlLayerWarnings,
+ getGlobalReferenceImageWarnings,
+ getInpaintMaskWarnings,
+ getRasterLayerWarnings,
+ getRegionalGuidanceWarnings,
+} from 'features/controlLayers/store/validators';
+import type { TFunction } from 'i18next';
+import { upperFirst } from 'lodash-es';
+import { memo, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+import { PiWarningBold } from 'react-icons/pi';
+import type { Equals } from 'tsafe';
+import { assert } from 'tsafe';
+
+/**
+ * Builds a selector that returns the warnings for a single canvas entity, translated and with the first letter
+ * upper-cased.
+ *
+ * @param entityIdentifier Identifies the entity whose warnings are selected
+ * @param t The translation function used to translate each warning key
+ * @returns A selector over the canvas slice and the model; it yields EMPTY_ARRAY when there are no warnings
+ */
+const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => {
+  return createSelector(selectCanvasSlice, selectModel, (canvas, model) => {
+    // This component is expected to be rendered only while the entity exists, so the lookup below should never
+    // throw.
+    const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings');
+
+    let warnings: string[] = [];
+
+    const entityType = entity.type;
+
+    if (entityType === 'control_layer') {
+      warnings = getControlLayerWarnings(entity, model);
+    } else if (entityType === 'regional_guidance') {
+      warnings = getRegionalGuidanceWarnings(entity, model);
+    } else if (entityType === 'inpaint_mask') {
+      warnings = getInpaintMaskWarnings(entity, model);
+    } else if (entityType === 'raster_layer') {
+      warnings = getRasterLayerWarnings(entity, model);
+    } else if (entityType === 'reference_image') {
+      warnings = getGlobalReferenceImageWarnings(entity, model);
+    } else {
+      // Exhaustiveness check: every known entity type is handled above, so entityType must be `never` here.
+      assert<Equals<typeof entityType, never>>(false, 'Unexpected entity type');
+    }
+
+    // Return a stable reference if there are no warnings
+    if (warnings.length === 0) {
+      return EMPTY_ARRAY;
+    }
+
+    // Translate each warning key, then upper-case its first letter
+    return warnings.map((w) => upperFirst(t(w)));
+  });
+};
+
+// Header indicator for the warnings of the entity provided by the entity identifier context.
+// Renders nothing when the entity has no warnings.
+export const CanvasEntityHeaderWarnings = memo(() => {
+ const entityIdentifier = useEntityIdentifierContext();
+ const { t } = useTranslation();
+ const isEnabled = useEntityIsEnabled(entityIdentifier);
+ // Rebuild the selector only when the entity identifier or the translation function changes
+ const selectWarnings = useMemo(() => buildSelectWarnings(entityIdentifier, t), [entityIdentifier, t]);
+ const warnings = useAppSelector(selectWarnings);
+
+ if (warnings.length === 0) {
+ return null;
+ }
+
+ return (
+ // Using IconButton here because it matches the styling of the actual buttons in the header without any
+ // finagling, but it's not a button
+ }
+ icon={}
+ colorScheme="warning"
+ isDisabled={!isEnabled}
+ />
+ );
+});
+
+CanvasEntityHeaderWarnings.displayName = 'CanvasEntityHeaderWarnings';
+
+// Tooltip body: the translated 'common.warnings' label followed by one entry per warning in props.warnings.
+const TooltipContent = memo((props: { warnings: string[] }) => {
+ const { t } = useTranslation();
+ return (
+
+ {t('common.warnings')}:
+
+ {props.warnings.map((warning, index) => (
+ {warning}
+ ))}
+
+
+ );
+});
+TooltipContent.displayName = 'TooltipContent';
diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts
new file mode 100644
index 0000000000..0b8a34c3bd
--- /dev/null
+++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts
@@ -0,0 +1,133 @@
+import type {
+ CanvasControlLayerState,
+ CanvasInpaintMaskState,
+ CanvasRasterLayerState,
+ CanvasReferenceImageState,
+ CanvasRegionalGuidanceState,
+} from 'features/controlLayers/store/types';
+import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
+
+/**
+ * Collects the translation keys of all warnings that apply to a regional guidance layer.
+ *
+ * @param entity The regional guidance layer to check
+ * @param model The main model, or null if there is none; model-specific checks are skipped when null
+ * @returns The warning translation keys, empty when nothing is wrong
+ */
+export const getRegionalGuidanceWarnings = (
+  entity: CanvasRegionalGuidanceState,
+  model: ParameterModel | null
+): string[] => {
+  // Layer is in empty state - no other check is made
+  if (entity.objects.length === 0) {
+    return ['parameters.invoke.layer.emptyLayer'];
+  }
+
+  const found: string[] = [];
+  const { positivePrompt, negativePrompt, referenceImages } = entity;
+
+  // Must have at least 1 prompt or IP Adapter
+  const hasNoGuidance = positivePrompt === null && negativePrompt === null && referenceImages.length === 0;
+  if (hasNoGuidance) {
+    found.push('parameters.invoke.layer.rgNoPromptsOrIPAdapters');
+  }
+
+  // Everything below depends on the model
+  if (!model) {
+    return found;
+  }
+
+  switch (model.base) {
+    case 'sd-3':
+    case 'sd-2':
+      // Unsupported model architecture
+      found.push('parameters.invoke.layer.unsupportedModel');
+      break;
+    case 'flux':
+      // Some features are not supported for flux models
+      if (negativePrompt !== null) {
+        found.push('parameters.invoke.layer.rgNegativePromptNotSupported');
+      }
+      if (referenceImages.length > 0) {
+        found.push('parameters.invoke.layer.rgReferenceImagesNotSupported');
+      }
+      if (entity.autoNegative) {
+        found.push('parameters.invoke.layer.rgAutoNegativeNotSupported');
+      }
+      break;
+    default:
+      for (const { ipAdapter } of referenceImages) {
+        if (!ipAdapter.model) {
+          // No model selected
+          found.push('parameters.invoke.layer.ipAdapterNoModelSelected');
+        } else if (ipAdapter.model.base !== model.base) {
+          // Supported model architecture but doesn't match
+          found.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel');
+        }
+
+        if (!ipAdapter.image) {
+          // No image selected
+          found.push('parameters.invoke.layer.ipAdapterNoImageSelected');
+        }
+      }
+  }
+
+  return found;
+};
+
+/**
+ * Collects the translation keys of all warnings that apply to a global reference image.
+ *
+ * @param entity The reference image entity to check
+ * @param model The main model, or null if there is none; the base model checks are skipped when null
+ * @returns The warning translation keys, empty when nothing is wrong
+ */
+export const getGlobalReferenceImageWarnings = (
+  entity: CanvasReferenceImageState,
+  model: ParameterModel | null
+): string[] => {
+  const found: string[] = [];
+  const { model: adapterModel, image } = entity.ipAdapter;
+
+  if (adapterModel && model) {
+    const isUnsupportedBase = model.base === 'sd-3' || model.base === 'sd-2';
+    if (isUnsupportedBase) {
+      // Unsupported model architecture
+      found.push('parameters.invoke.layer.unsupportedModel');
+    } else if (adapterModel.base !== model.base) {
+      // Supported model architecture but doesn't match
+      found.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel');
+    }
+  } else if (!adapterModel) {
+    // No model selected
+    found.push('parameters.invoke.layer.ipAdapterNoModelSelected');
+  }
+
+  // The image check is independent of the model checks
+  if (!image) {
+    // No image selected
+    found.push('parameters.invoke.layer.ipAdapterNoImageSelected');
+  }
+
+  return found;
+};
+
+/**
+ * Returns the translation keys of the warnings that apply to a control layer.
+ *
+ * The checks run in order and stop at the first one that fails, so at most one warning is returned.
+ *
+ * @param entity The control layer to check
+ * @param model The main model, or null if there is none; the base model checks are skipped when null
+ * @returns The warning translation keys, empty when nothing is wrong
+ */
+export const getControlLayerWarnings = (entity: CanvasControlLayerState, model: ParameterModel | null): string[] => {
+  // Layer is in empty state - no other check is made
+  if (entity.objects.length === 0) {
+    return ['parameters.invoke.layer.emptyLayer'];
+  }
+
+  const adapterModel = entity.controlAdapter.model;
+
+  // No model selected
+  if (!adapterModel) {
+    return ['parameters.invoke.layer.controlAdapterNoModelSelected'];
+  }
+
+  if (!model) {
+    return [];
+  }
+
+  // Unsupported model architecture
+  if (model.base === 'sd-3' || model.base === 'sd-2') {
+    return ['parameters.invoke.layer.unsupportedModel'];
+  }
+
+  // Supported model architecture but doesn't match
+  if (adapterModel.base !== model.base) {
+    return ['parameters.invoke.layer.controlAdapterIncompatibleBaseModel'];
+  }
+
+  return [];
+};
+
+/**
+ * Returns the translation keys of the warnings that apply to a raster layer.
+ *
+ * @param entity The raster layer to check
+ * @param _model Unused
+ * @returns The empty-layer warning key when the layer has no objects, otherwise an empty array
+ */
+export const getRasterLayerWarnings = (entity: CanvasRasterLayerState, _model: ParameterModel | null): string[] => {
+  // Layer is in empty state
+  return entity.objects.length === 0 ? ['parameters.invoke.layer.emptyLayer'] : [];
+};
+
+/**
+ * Returns the translation keys of the warnings that apply to an inpaint mask.
+ *
+ * @param entity The inpaint mask to check
+ * @param _model Unused
+ * @returns The empty-layer warning key when the mask has no objects, otherwise an empty array
+ */
+export const getInpaintMaskWarnings = (entity: CanvasInpaintMaskState, _model: ParameterModel | null): string[] => {
+  // Layer is in empty state
+  const isEmpty = entity.objects.length === 0;
+  if (isEmpty) {
+    return ['parameters.invoke.layer.emptyLayer'];
+  }
+  return [];
+};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
index d012ca2853..a41dc06550 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts
@@ -1,35 +1,41 @@
import { logger } from 'app/logging/logger';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
-import type {
- CanvasControlLayerState,
- ControlNetConfig,
- Rect,
- T2IAdapterConfig,
-} from 'features/controlLayers/store/types';
+import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types';
+import { getControlLayerWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
-import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types';
+import type { ImageDTO, Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
+// Argument object for addControlNets.
+type AddControlNetsArg = {
+ manager: CanvasManager;
+ // Candidate control layers. addControlNets keeps only those that are enabled, have a 'controlnet' control adapter
+ // and produce no warnings from getControlLayerWarnings.
+ entities: CanvasControlLayerState[];
+ g: Graph;
+ rect: Rect;
+ collector: Invocation<'collect'>;
+ // Model passed to getControlLayerWarnings when validating each layer.
+ model: ParameterModel;
+};
+
type AddControlNetsResult = {
addedControlNets: number;
};
-export const addControlNets = async (
- manager: CanvasManager,
- layers: CanvasControlLayerState[],
- g: Graph,
- rect: Rect,
- collector: Invocation<'collect'>,
- base: BaseModelType
-): Promise => {
- const validControlLayers = layers
- .filter((layer) => layer.isEnabled)
- .filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
- .filter((layer) => layer.controlAdapter.type === 'controlnet');
+export const addControlNets = async ({
+ manager,
+ entities,
+ g,
+ rect,
+ collector,
+ model,
+}: AddControlNetsArg): Promise => {
+ const validControlLayers = entities
+ .filter((entity) => entity.isEnabled)
+ .filter((entity) => entity.controlAdapter.type === 'controlnet')
+ .filter((entity) => getControlLayerWarnings(entity, model).length === 0);
const result: AddControlNetsResult = {
addedControlNets: 0,
@@ -54,22 +60,31 @@ export const addControlNets = async (
return result;
};
+// Argument object for addT2IAdapters.
+type AddT2IAdaptersArg = {
+ manager: CanvasManager;
+ // Candidate control layers. addT2IAdapters keeps only those that are enabled, have a 't2i_adapter' control adapter
+ // and produce no warnings from getControlLayerWarnings.
+ entities: CanvasControlLayerState[];
+ g: Graph;
+ rect: Rect;
+ collector: Invocation<'collect'>;
+ // Model passed to getControlLayerWarnings when validating each layer.
+ model: ParameterModel;
+};
+
type AddT2IAdaptersResult = {
addedT2IAdapters: number;
};
-export const addT2IAdapters = async (
- manager: CanvasManager,
- layers: CanvasControlLayerState[],
- g: Graph,
- rect: Rect,
- collector: Invocation<'collect'>,
- base: BaseModelType
-): Promise => {
- const validControlLayers = layers
- .filter((layer) => layer.isEnabled)
- .filter((layer) => isValidControlAdapter(layer.controlAdapter, base))
- .filter((layer) => layer.controlAdapter.type === 't2i_adapter');
+export const addT2IAdapters = async ({
+ manager,
+ entities,
+ g,
+ rect,
+ collector,
+ model,
+}: AddT2IAdaptersArg): Promise => {
+ const validControlLayers = entities
+ .filter((entity) => entity.isEnabled)
+ .filter((entity) => entity.controlAdapter.type === 't2i_adapter')
+ .filter((entity) => getControlLayerWarnings(entity, model).length === 0);
const result: AddT2IAdaptersResult = {
addedT2IAdapters: 0,
@@ -145,11 +160,3 @@ const addT2IAdapterToGraph = (
g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item');
};
-
-const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => {
- // Must be have a model
- const hasModel = Boolean(controlAdapter.model);
- // Model must match the current base model
- const modelMatchesBase = controlAdapter.model?.base === base;
- return hasModel && modelMatchesBase;
-};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts
index 81b98f3ef5..0a3a43a018 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts
@@ -1,19 +1,23 @@
import type { CanvasReferenceImageState } from 'features/controlLayers/store/types';
+import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
-import type { BaseModelType, Invocation } from 'services/api/types';
+import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
+import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
type AddIPAdaptersResult = {
addedIPAdapters: number;
};
-export const addIPAdapters = (
- ipAdapters: CanvasReferenceImageState[],
- g: Graph,
- collector: Invocation<'collect'>,
- base: BaseModelType
-): AddIPAdaptersResult => {
- const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity, base));
+// Argument object for addIPAdapters.
+type AddIPAdaptersArg = {
+ // Candidate reference images. addIPAdapters keeps only those with no warnings from
+ // getGlobalReferenceImageWarnings.
+ entities: CanvasReferenceImageState[];
+ g: Graph;
+ collector: Invocation<'collect'>;
+ // Model passed to getGlobalReferenceImageWarnings when validating each reference image.
+ model: ParameterModel;
+};
+
+export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => {
+ const validIPAdapters = entities.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0);
const result: AddIPAdaptersResult = {
addedIPAdapters: 0,
@@ -76,11 +80,3 @@ const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: In
g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item');
};
-
-const isValidIPAdapter = ({ isEnabled, ipAdapter }: CanvasReferenceImageState, base: BaseModelType): boolean => {
- // Must be have a model that matches the current base and must have a control image
- const hasModel = Boolean(ipAdapter.model);
- const modelMatchesBase = ipAdapter.model?.base === base;
- const hasImage = Boolean(ipAdapter.image);
- return isEnabled && hasModel && modelMatchesBase && hasImage;
-};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts
index dcce2046da..a1921200d5 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts
@@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone';
import { withResultAsync } from 'common/util/result';
import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager';
import { getPrefixedId } from 'features/controlLayers/konva/util';
-import type {
- CanvasRegionalGuidanceState,
- IPAdapterConfig,
- Rect,
- RegionalGuidanceReferenceImageState,
-} from 'features/controlLayers/store/types';
+import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types';
+import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators';
import type { Graph } from 'features/nodes/util/graph/generation/Graph';
+import type { ParameterModel } from 'features/parameters/types/parameterSchemas';
import { serializeError } from 'serialize-error';
-import type { BaseModelType, Invocation } from 'services/api/types';
+import type { Invocation } from 'services/api/types';
import { assert } from 'tsafe';
const log = logger('system');
@@ -23,19 +20,26 @@ type AddedRegionResult = {
addedIPAdapters: number;
};
-const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => {
- const isEnabled = rg.isEnabled;
- const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt);
- const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0;
- return isEnabled && (hasTextPrompt || hasIPAdapter);
+// Argument object for addRegions.
+type AddRegionsArg = {
+ manager: CanvasManager;
+ // Candidate regions. addRegions keeps only those that are enabled and produce no warnings from
+ // getRegionalGuidanceWarnings.
+ regions: CanvasRegionalGuidanceState[];
+ g: Graph;
+ bbox: Rect;
+ // addRegions reads model.base to choose the conditioning node type for each regional prompt.
+ model: ParameterModel;
+ posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
+ // negCond and negCondCollect may be null; addRegions asserts that both are present before it adds a region's
+ // negative prompt.
+ negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null;
+ posCondCollect: Invocation<'collect'>;
+ negCondCollect: Invocation<'collect'> | null;
+ ipAdapterCollect: Invocation<'collect'>;
+};
/**
* Adds regional guidance to the graph
+ * @param manager The canvas manager
* @param regions Array of regions to add
* @param g The graph to add the layers to
- * @param base The base model type
- * @param denoise The main denoise node
+ * @param bbox The bounding box
+ * @param model The main model
* @param posCond The positive conditioning node
* @param negCond The negative conditioning node
* @param posCondCollect The positive conditioning collector
@@ -44,22 +48,28 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) =>
* @returns A promise that resolves to the regions that were successfully added to the graph
*/
-export const addRegions = async (
- manager: CanvasManager,
- regions: CanvasRegionalGuidanceState[],
- g: Graph,
- bbox: Rect,
- base: BaseModelType,
- denoise: Invocation<'denoise_latents'>,
- posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
- negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>,
- posCondCollect: Invocation<'collect'>,
- negCondCollect: Invocation<'collect'>,
- ipAdapterCollect: Invocation<'collect'>
-): Promise => {
- const isSDXL = base === 'sdxl';
+export const addRegions = async ({
+ manager,
+ regions,
+ g,
+ bbox,
+ model,
+ posCond,
+ negCond,
+ posCondCollect,
+ negCondCollect,
+ ipAdapterCollect,
+}: AddRegionsArg): Promise => {
+ const isSDXL = model.base === 'sdxl';
+ const isFLUX = model.base === 'flux';
+
+ const validRegions = regions.filter((rg) => {
+ if (!rg.isEnabled) {
+ return false;
+ }
+ return getRegionalGuidanceWarnings(rg, model).length === 0;
+ });
- const validRegions = regions.filter((rg) => isValidRegion(rg, base));
const results: AddedRegionResult[] = [];
for (const region of validRegions) {
@@ -94,20 +104,27 @@ export const addRegions = async (
if (region.positivePrompt) {
// The main positive conditioning node
result.addedPositivePrompt = true;
- const regionalPosCond = g.addNode(
- isSDXL
- ? {
- type: 'sdxl_compel_prompt',
- id: getPrefixedId('prompt_region_positive_cond'),
- prompt: region.positivePrompt,
- style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
- }
- : {
- type: 'compel',
- id: getPrefixedId('prompt_region_positive_cond'),
- prompt: region.positivePrompt,
- }
- );
+ let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
+ if (isSDXL) {
+ regionalPosCond = g.addNode({
+ type: 'sdxl_compel_prompt',
+ id: getPrefixedId('prompt_region_positive_cond'),
+ prompt: region.positivePrompt,
+ style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields?
+ });
+ } else if (isFLUX) {
+ regionalPosCond = g.addNode({
+ type: 'flux_text_encoder',
+ id: getPrefixedId('prompt_region_positive_cond'),
+ prompt: region.positivePrompt,
+ });
+ } else {
+ regionalPosCond = g.addNode({
+ type: 'compel',
+ id: getPrefixedId('prompt_region_positive_cond'),
+ prompt: region.positivePrompt,
+ });
+ }
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask');
// Connect the conditioning to the collector
@@ -115,38 +132,55 @@ export const addRegions = async (
// Copy the connections to the "global" positive conditioning node to the regional cond
if (posCond.type === 'compel') {
for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) {
- // Clone the edge, but change the destination node to the regional conditioning node
+ const clone = deepClone(edge);
+ clone.destination.node_id = regionalPosCond.id;
+ g.addEdgeFromObj(clone);
+ }
+ } else if (posCond.type === 'sdxl_compel_prompt') {
+ for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
+ const clone = deepClone(edge);
+ clone.destination.node_id = regionalPosCond.id;
+ g.addEdgeFromObj(clone);
+ }
+ } else if (posCond.type === 'flux_text_encoder') {
+ for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCond.id;
g.addEdgeFromObj(clone);
}
} else {
- for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
- // Clone the edge, but change the destination node to the regional conditioning node
- const clone = deepClone(edge);
- clone.destination.node_id = regionalPosCond.id;
- g.addEdgeFromObj(clone);
- }
+ assert(false, 'Unsupported positive conditioning node type.');
}
}
if (region.negativePrompt) {
- result.addedNegativePrompt = true;
+ assert(negCond, 'Negative conditioning node is required if there is a negative prompt');
+ assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt');
+
// The main negative conditioning node
- const regionalNegCond = g.addNode(
- isSDXL
- ? {
- type: 'sdxl_compel_prompt',
- id: getPrefixedId('prompt_region_negative_cond'),
- prompt: region.negativePrompt,
- style: region.negativePrompt,
- }
- : {
- type: 'compel',
- id: getPrefixedId('prompt_region_negative_cond'),
- prompt: region.negativePrompt,
- }
- );
+ result.addedNegativePrompt = true;
+ let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
+ if (isSDXL) {
+ regionalNegCond = g.addNode({
+ type: 'sdxl_compel_prompt',
+ id: getPrefixedId('prompt_region_negative_cond'),
+ prompt: region.negativePrompt,
+ style: region.negativePrompt,
+ });
+ } else if (isFLUX) {
+ regionalNegCond = g.addNode({
+ type: 'flux_text_encoder',
+ id: getPrefixedId('prompt_region_negative_cond'),
+ prompt: region.negativePrompt,
+ });
+ } else {
+ regionalNegCond = g.addNode({
+ type: 'compel',
+ id: getPrefixedId('prompt_region_negative_cond'),
+ prompt: region.negativePrompt,
+ });
+ }
+
// Connect the mask to the conditioning
g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask');
// Connect the conditioning to the collector
@@ -158,17 +192,27 @@ export const addRegions = async (
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
- } else {
+ } else if (negCond.type === 'sdxl_compel_prompt') {
for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalNegCond.id;
g.addEdgeFromObj(clone);
}
+ } else if (negCond.type === 'flux_text_encoder') {
+ for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
+ const clone = deepClone(edge);
+ clone.destination.node_id = regionalNegCond.id;
+ g.addEdgeFromObj(clone);
+ }
+ } else {
+ assert(false, 'Unsupported negative conditioning node type.');
}
}
// If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node
if (region.autoNegative && region.positivePrompt) {
+ assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting');
+
result.addedAutoNegativePositivePrompt = true;
// We re-use the mask image, but invert it when converting to tensor
const invertTensorMask = g.addNode({
@@ -178,20 +222,27 @@ export const addRegions = async (
// Connect the OG mask image to the inverted mask-to-tensor node
g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask');
// Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt
- const regionalPosCondInverted = g.addNode(
- isSDXL
- ? {
- type: 'sdxl_compel_prompt',
- id: getPrefixedId('prompt_region_positive_cond_inverted'),
- prompt: region.positivePrompt,
- style: region.positivePrompt,
- }
- : {
- type: 'compel',
- id: getPrefixedId('prompt_region_positive_cond_inverted'),
- prompt: region.positivePrompt,
- }
- );
+ let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>;
+ if (isSDXL) {
+ regionalPosCondInverted = g.addNode({
+ type: 'sdxl_compel_prompt',
+ id: getPrefixedId('prompt_region_positive_cond_inverted'),
+ prompt: region.positivePrompt,
+ style: region.positivePrompt,
+ });
+ } else if (isFLUX) {
+ regionalPosCondInverted = g.addNode({
+ type: 'flux_text_encoder',
+ id: getPrefixedId('prompt_region_positive_cond_inverted'),
+ prompt: region.positivePrompt,
+ });
+ } else {
+ regionalPosCondInverted = g.addNode({
+ type: 'compel',
+ id: getPrefixedId('prompt_region_positive_cond_inverted'),
+ prompt: region.positivePrompt,
+ });
+ }
// Connect the inverted mask to the conditioning
g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask');
// Connect the conditioning to the negative collector
@@ -203,20 +254,26 @@ export const addRegions = async (
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
- } else {
+ } else if (posCond.type === 'sdxl_compel_prompt') {
for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) {
const clone = deepClone(edge);
clone.destination.node_id = regionalPosCondInverted.id;
g.addEdgeFromObj(clone);
}
+ } else if (posCond.type === 'flux_text_encoder') {
+ for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) {
+ const clone = deepClone(edge);
+ clone.destination.node_id = regionalPosCondInverted.id;
+ g.addEdgeFromObj(clone);
+ }
+ } else {
+ assert(false, 'Unsupported positive conditioning node type.');
}
}
- const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) =>
- isValidIPAdapter(ipAdapter, base)
- );
+ for (const { id, ipAdapter } of region.referenceImages) {
+ assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.');
- for (const { id, ipAdapter } of validRGIPAdapters) {
result.addedIPAdapters++;
const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter;
assert(model, 'IP Adapter model is required');
@@ -248,11 +305,3 @@ export const addRegions = async (
return results;
};
-
-const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => {
- // Must be have a model that matches the current base and must have a control image
- const hasModel = Boolean(ipAdapter.model);
- const modelMatchesBase = ipAdapter.model?.base === base;
- const hasImage = Boolean(ipAdapter.image);
- return hasModel && modelMatchesBase && hasImage;
-};
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
index d893760f3c..885954e995 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts
@@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo
import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint';
import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker';
import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint';
+import { addRegions } from 'features/nodes/util/graph/generation/addRegions';
import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage';
import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker';
import { Graph } from 'features/nodes/util/graph/generation/Graph';
@@ -79,7 +80,10 @@ export const buildFLUXGraph = async (
id: getPrefixedId('flux_text_encoder'),
prompt: positivePrompt,
});
-
+ const posCondCollect = g.addNode({
+ type: 'collect',
+ id: getPrefixedId('pos_cond_collect'),
+ });
const denoise = g.addNode({
type: 'flux_denoise',
id: getPrefixedId('flux_denoise'),
@@ -104,13 +108,12 @@ export const buildFLUXGraph = async (
g.addEdge(modelLoader, 'clip', posCond, 'clip');
g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder');
g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len');
+ g.addEdge(posCond, 'conditioning', posCondCollect, 'item');
+ g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning');
+ g.addEdge(denoise, 'latents', l2i, 'latents');
addFLUXLoRAs(state, g, denoise, modelLoader, posCond);
- g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning');
-
- g.addEdge(denoise, 'latents', l2i, 'latents');
-
const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig);
assert(modelConfig.base === 'flux');
@@ -196,31 +199,50 @@ export const buildFLUXGraph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
- const controlNetResult = await addControlNets(
+ const controlNetResult = await addControlNets({
manager,
- canvas.controlLayers.entities,
+ entities: canvas.controlLayers.entities,
g,
- canvas.bbox.rect,
- controlNetCollector,
- modelConfig.base
- );
+ rect: canvas.bbox.rect,
+ collector: controlNetCollector,
+ model: modelConfig,
+ });
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
g.deleteNode(controlNetCollector.id);
}
- const ipAdapterCollector = g.addNode({
+ const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
- const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
+ const ipAdapterResult = addIPAdapters({
+ entities: canvas.referenceImages.entities,
+ g,
+ collector: ipAdapterCollect,
+ model: modelConfig,
+ });
- const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters;
+ const regionsResult = await addRegions({
+ manager,
+ regions: canvas.regionalGuidance.entities,
+ g,
+ bbox: canvas.bbox.rect,
+ model: modelConfig,
+ posCond,
+ negCond: null,
+ posCondCollect,
+ negCondCollect: null,
+ ipAdapterCollect,
+ });
+
+ const totalIPAdaptersAdded =
+ ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
- g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
+ g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
- g.deleteNode(ipAdapterCollector.id);
+ g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts
index 4008fd05d6..7522227007 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts
@@ -227,14 +227,14 @@ export const buildSD1Graph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
- const controlNetResult = await addControlNets(
+ const controlNetResult = await addControlNets({
manager,
- canvas.controlLayers.entities,
+ entities: canvas.controlLayers.entities,
g,
- canvas.bbox.rect,
- controlNetCollector,
- modelConfig.base
- );
+ rect: canvas.bbox.rect,
+ collector: controlNetCollector,
+ model: modelConfig,
+ });
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
@@ -245,46 +245,50 @@ export const buildSD1Graph = async (
type: 'collect',
id: getPrefixedId('t2i_adapter_collector'),
});
- const t2iAdapterResult = await addT2IAdapters(
+ const t2iAdapterResult = await addT2IAdapters({
manager,
- canvas.controlLayers.entities,
+ entities: canvas.controlLayers.entities,
g,
- canvas.bbox.rect,
- t2iAdapterCollector,
- modelConfig.base
- );
+ rect: canvas.bbox.rect,
+ collector: t2iAdapterCollector,
+ model: modelConfig,
+ });
if (t2iAdapterResult.addedT2IAdapters > 0) {
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
} else {
g.deleteNode(t2iAdapterCollector.id);
}
- const ipAdapterCollector = g.addNode({
+ const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
- const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
-
- const regionsResult = await addRegions(
- manager,
- canvas.regionalGuidance.entities,
+ const ipAdapterResult = addIPAdapters({
+ entities: canvas.referenceImages.entities,
g,
- canvas.bbox.rect,
- modelConfig.base,
- denoise,
+ collector: ipAdapterCollect,
+ model: modelConfig,
+ });
+
+ const regionsResult = await addRegions({
+ manager,
+ regions: canvas.regionalGuidance.entities,
+ g,
+ bbox: canvas.bbox.rect,
+ model: modelConfig,
posCond,
negCond,
posCondCollect,
negCondCollect,
- ipAdapterCollector
- );
+ ipAdapterCollect,
+ });
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
- g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
+ g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
- g.deleteNode(ipAdapterCollector.id);
+ g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {
diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
index 37ab697522..9357a291b4 100644
--- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
+++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts
@@ -232,14 +232,14 @@ export const buildSDXLGraph = async (
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
- const controlNetResult = await addControlNets(
+ const controlNetResult = await addControlNets({
manager,
- canvas.controlLayers.entities,
+ entities: canvas.controlLayers.entities,
g,
- canvas.bbox.rect,
- controlNetCollector,
- modelConfig.base
- );
+ rect: canvas.bbox.rect,
+ collector: controlNetCollector,
+ model: modelConfig,
+ });
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', denoise, 'control');
} else {
@@ -250,46 +250,50 @@ export const buildSDXLGraph = async (
type: 'collect',
id: getPrefixedId('t2i_adapter_collector'),
});
- const t2iAdapterResult = await addT2IAdapters(
+ const t2iAdapterResult = await addT2IAdapters({
manager,
- canvas.controlLayers.entities,
+ entities: canvas.controlLayers.entities,
g,
- canvas.bbox.rect,
- t2iAdapterCollector,
- modelConfig.base
- );
+ rect: canvas.bbox.rect,
+ collector: t2iAdapterCollector,
+ model: modelConfig,
+ });
if (t2iAdapterResult.addedT2IAdapters > 0) {
g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter');
} else {
g.deleteNode(t2iAdapterCollector.id);
}
- const ipAdapterCollector = g.addNode({
+ const ipAdapterCollect = g.addNode({
type: 'collect',
id: getPrefixedId('ip_adapter_collector'),
});
- const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base);
-
- const regionsResult = await addRegions(
- manager,
- canvas.regionalGuidance.entities,
+ const ipAdapterResult = addIPAdapters({
+ entities: canvas.referenceImages.entities,
g,
- canvas.bbox.rect,
- modelConfig.base,
- denoise,
+ collector: ipAdapterCollect,
+ model: modelConfig,
+ });
+
+ const regionsResult = await addRegions({
+ manager,
+ regions: canvas.regionalGuidance.entities,
+ g,
+ bbox: canvas.bbox.rect,
+ model: modelConfig,
posCond,
negCond,
posCondCollect,
negCondCollect,
- ipAdapterCollector
- );
+ ipAdapterCollect,
+ });
const totalIPAdaptersAdded =
ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0);
if (totalIPAdaptersAdded > 0) {
- g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter');
+ g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter');
} else {
- g.deleteNode(ipAdapterCollector.id);
+ g.deleteNode(ipAdapterCollect.id);
}
if (state.system.shouldUseNSFWChecker) {
diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts
index 0d14bba73d..0af16bdf78 100644
--- a/invokeai/frontend/web/src/features/queue/store/readiness.ts
+++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts
@@ -4,6 +4,13 @@ import type { ParamsState } from 'features/controlLayers/store/paramsSlice';
import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice } from 'features/controlLayers/store/selectors';
import type { CanvasState } from 'features/controlLayers/store/types';
+import {
+ getControlLayerWarnings,
+ getGlobalReferenceImageWarnings,
+ getInpaintMaskWarnings,
+ getRasterLayerWarnings,
+ getRegionalGuidanceWarnings,
+} from 'features/controlLayers/store/validators';
import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice';
import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt';
@@ -278,17 +285,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
- const problems: string[] = [];
- // Must have model
- if (!controlLayer.controlAdapter.model) {
- problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected'));
- }
- // Model base must match
- if (controlLayer.controlAdapter.model?.base !== model?.base) {
- problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'));
- }
+ const problems = getControlLayerWarnings(controlLayer, model);
+
if (problems.length) {
- const content = upperFirst(problems.join(', '));
+ const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});
@@ -300,23 +300,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
- const problems: string[] = [];
-
- // Must have model
- if (!entity.ipAdapter.model) {
- problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
- }
- // Model base must match
- if (entity.ipAdapter.model?.base !== model?.base) {
- problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
- }
- // Must have an image
- if (!entity.ipAdapter.image) {
- problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
- }
+ const problems = getGlobalReferenceImageWarnings(entity, model);
if (problems.length) {
- const content = upperFirst(problems.join(', '));
+ const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});
@@ -328,32 +315,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
- const problems: string[] = [];
- // Must have a region
- if (entity.objects.length === 0) {
- problems.push(i18n.t('parameters.invoke.layer.rgNoRegion'));
- }
- // Must have at least 1 prompt or IP Adapter
- if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) {
- problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters'));
- }
- entity.referenceImages.forEach(({ ipAdapter }) => {
- // Must have model
- if (!ipAdapter.model) {
- problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected'));
- }
- // Model base must match
- if (ipAdapter.model?.base !== model?.base) {
- problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'));
- }
- // Must have an image
- if (!ipAdapter.image) {
- problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected'));
- }
- });
+ const problems = getRegionalGuidanceWarnings(entity, model);
if (problems.length) {
- const content = upperFirst(problems.join(', '));
+ const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});
@@ -365,10 +330,25 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: {
const layerNumber = i + 1;
const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
- const problems: string[] = [];
+ const problems = getRasterLayerWarnings(entity, model);
if (problems.length) {
- const content = upperFirst(problems.join(', '));
+ const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
+ reasons.push({ prefix, content });
+ }
+ });
+
+ canvas.inpaintMasks.entities
+ .filter((entity) => entity.isEnabled)
+ .forEach((entity, i) => {
+ const layerLiteral = i18n.t('controlLayers.layer_one');
+ const layerNumber = i + 1;
+ const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]);
+ const prefix = `${layerLiteral} #${layerNumber} (${layerType})`;
+ const problems = getInpaintMaskWarnings(entity, model);
+
+ if (problems.length) {
+ const content = upperFirst(problems.map((p) => i18n.t(p)).join(', '));
reasons.push({ prefix, content });
}
});
diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts
index 90adf7517e..8878a4bcfa 100644
--- a/invokeai/frontend/web/src/services/api/schema.ts
+++ b/invokeai/frontend/web/src/services/api/schema.ts
@@ -6564,6 +6564,11 @@ export type components = {
* @description The name of conditioning tensor
*/
conditioning_name: string;
+ /**
+ * @description The mask associated with this conditioning tensor. Excluded regions should be set to False, included regions should be set to True.
+ * @default null
+ */
+ mask?: components["schemas"]["TensorField"] | null;
};
/**
* FluxConditioningOutput
@@ -6771,15 +6776,17 @@ export type components = {
*/
transformer?: components["schemas"]["TransformerField"];
/**
+ * Positive Text Conditioning
* @description Positive conditioning tensor
* @default null
*/
- positive_text_conditioning?: components["schemas"]["FluxConditioningField"];
+ positive_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][];
/**
+ * Negative Text Conditioning
* @description Negative conditioning tensor. Can be None if cfg_scale is 1.0.
* @default null
*/
- negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | null;
+ negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][] | null;
/**
* CFG Scale
* @description Classifier-Free Guidance scale
@@ -7133,6 +7140,11 @@ export type components = {
* @default null
*/
prompt?: string;
+ /**
+ * @description A mask defining the region that this conditioning prompt applies to.
+ * @default null
+ */
+ mask?: components["schemas"]["TensorField"] | null;
/**
* type
* @default flux_text_encoder