diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 5e76931933..16bfcaeed7 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel): """A conditioning tensor primitive value""" conditioning_name: str = Field(description="The name of conditioning tensor") + mask: Optional[TensorField] = Field( + default=None, + description="The mask associated with this conditioning tensor. Excluded regions should be set to False, " + "included regions should be set to True.", + ) class SD3ConditioningField(BaseModel): diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 9e197626b5..f2a2dd586a 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -30,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN from invokeai.backend.flux.denoise import denoise from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension +from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux @@ -42,6 +43,7 @@ from invokeai.backend.flux.sampling_utils import ( pack, unpack, ) +from invokeai.backend.flux.text_conditioning import FluxTextConditioning from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX from invokeai.backend.lora.lora_model_raw import LoRAModelRaw from invokeai.backend.lora.lora_patcher import LoRAPatcher @@ -87,10 +89,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): input=Input.Connection, title="Transformer", ) - positive_text_conditioning: FluxConditioningField = InputField( + positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField( description=FieldDescriptions.positive_cond, input=Input.Connection ) - negative_text_conditioning: FluxConditioningField | None = InputField( + negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField( default=None, description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.", input=Input.Connection, @@ -139,36 +141,12 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): name = context.tensors.save(tensor=latents) return LatentsOutput.build(latents_name=name, latents=latents, seed=None) - def _load_text_conditioning( - self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Load the conditioning data. - cond_data = context.conditioning.load(conditioning_name) - assert len(cond_data.conditionings) == 1 - flux_conditioning = cond_data.conditionings[0] - assert isinstance(flux_conditioning, FLUXConditioningInfo) - flux_conditioning = flux_conditioning.to(dtype=dtype) - t5_embeddings = flux_conditioning.t5_embeds - clip_embeddings = flux_conditioning.clip_embeds - return t5_embeddings, clip_embeddings - def _run_diffusion( self, context: InvocationContext, ): inference_dtype = torch.bfloat16 - # Load the conditioning data. - pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning( - context, self.positive_text_conditioning.conditioning_name, inference_dtype - ) - neg_t5_embeddings: torch.Tensor | None = None - neg_clip_embeddings: torch.Tensor | None = None - if self.negative_text_conditioning is not None: - neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning( - context, self.negative_text_conditioning.conditioning_name, inference_dtype - ) - # Load the input latents, if provided. init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None if init_latents is not None: @@ -183,15 +161,45 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): dtype=inference_dtype, seed=self.seed, ) + b, _c, latent_h, latent_w = noise.shape + packed_h = latent_h // 2 + packed_w = latent_w // 2 + + # Load the conditioning data. + pos_text_conditionings = self._load_text_conditioning( + context=context, + cond_field=self.positive_text_conditioning, + packed_height=packed_h, + packed_width=packed_w, + dtype=inference_dtype, + device=TorchDevice.choose_torch_device(), + ) + neg_text_conditionings: list[FluxTextConditioning] | None = None + if self.negative_text_conditioning is not None: + neg_text_conditionings = self._load_text_conditioning( + context=context, + cond_field=self.negative_text_conditioning, + packed_height=packed_h, + packed_width=packed_w, + dtype=inference_dtype, + device=TorchDevice.choose_torch_device(), + ) + pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning( + pos_text_conditionings, img_seq_len=packed_h * packed_w + ) + neg_regional_prompting_extension = ( + RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w) + if neg_text_conditionings + else None + ) transformer_info = context.models.load(self.transformer.transformer) is_schnell = "schnell" in transformer_info.config.config_path # Calculate the timestep schedule. - image_seq_len = noise.shape[-1] * noise.shape[-2] // 4 timesteps = get_schedule( num_steps=self.num_steps, - image_seq_len=image_seq_len, + image_seq_len=packed_h * packed_w, shift=not is_schnell, ) @@ -228,28 +236,17 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): inpaint_mask = self._prep_inpaint_mask(context, x) - b, _c, latent_h, latent_w = x.shape img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype) - pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape - pos_txt_ids = torch.zeros( - pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device() - ) - neg_txt_ids: torch.Tensor | None = None - if neg_t5_embeddings is not None: - neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape - neg_txt_ids = torch.zeros( - neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device() - ) - # Pack all latent tensors. init_latents = pack(init_latents) if init_latents is not None else None inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None noise = pack(noise) x = pack(x) - # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly. - assert image_seq_len == x.shape[1] + # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and + # packed_w correctly. + assert packed_h * packed_w == x.shape[1] # Prepare inpaint extension. inpaint_extension: InpaintExtension | None = None @@ -338,12 +335,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): model=transformer, img=x, img_ids=img_ids, - txt=pos_t5_embeddings, - txt_ids=pos_txt_ids, - vec=pos_clip_embeddings, - neg_txt=neg_t5_embeddings, - neg_txt_ids=neg_txt_ids, - neg_vec=neg_clip_embeddings, + pos_regional_prompting_extension=pos_regional_prompting_extension, + neg_regional_prompting_extension=neg_regional_prompting_extension, timesteps=timesteps, step_callback=self._build_step_callback(context), guidance=self.guidance, @@ -357,6 +350,43 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): x = unpack(x.float(), self.height, self.width) return x + def _load_text_conditioning( + self, + context: InvocationContext, + cond_field: FluxConditioningField | list[FluxConditioningField], + packed_height: int, + packed_width: int, + dtype: torch.dtype, + device: torch.device, + ) -> list[FluxTextConditioning]: + """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields.""" + # Normalize to a list of FluxConditioningFields. + cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field + + text_conditionings: list[FluxTextConditioning] = [] + for cond_field in cond_list: + # Load the text embeddings. + cond_data = context.conditioning.load(cond_field.conditioning_name) + assert len(cond_data.conditionings) == 1 + flux_conditioning = cond_data.conditionings[0] + assert isinstance(flux_conditioning, FLUXConditioningInfo) + flux_conditioning = flux_conditioning.to(dtype=dtype, device=device) + t5_embeddings = flux_conditioning.t5_embeds + clip_embeddings = flux_conditioning.clip_embeds + + # Load the mask, if provided. + mask: Optional[torch.Tensor] = None + if cond_field.mask is not None: + mask = context.tensors.load(cond_field.mask.tensor_name) + mask = mask.to(device=device) + mask = RegionalPromptingExtension.preprocess_regional_prompt_mask( + mask, packed_height, packed_width, dtype, device + ) + + text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask)) + + return text_conditionings + @classmethod def prep_cfg_scale( cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 91c89cb31b..1eb0fea62e 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -1,11 +1,18 @@ from contextlib import ExitStack -from typing import Iterator, Literal, Tuple +from typing import Iterator, Literal, Optional, Tuple import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, UIComponent +from invokeai.app.invocations.fields import ( + FieldDescriptions, + FluxConditioningField, + Input, + InputField, + TensorField, + UIComponent, +) from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext @@ -41,9 +48,9 @@ class FluxTextEncoderInvocation(BaseInvocation): t5_max_seq_len: Literal[256, 512] = InputField( description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models." ) - prompt: str = InputField( - description="Text prompt to encode.", - ui_component=UIComponent.Textarea, + prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea) + mask: Optional[TensorField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." ) @torch.no_grad() @@ -57,7 +64,9 @@ class FluxTextEncoderInvocation(BaseInvocation): ) conditioning_name = context.conditioning.save(conditioning_data) - return FluxConditioningOutput.build(conditioning_name) + return FluxConditioningOutput( + conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask) + ) def _t5_encode(self, context: InvocationContext) -> torch.Tensor: t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer) diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py index e0c7779e93..0f56adacde 100644 --- a/invokeai/backend/flux/custom_block_processor.py +++ b/invokeai/backend/flux/custom_block_processor.py @@ -1,9 +1,10 @@ import einops import torch +from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.math import attention -from invokeai.backend.flux.modules.layers import DoubleStreamBlock +from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock class CustomDoubleStreamBlockProcessor: @@ -13,7 +14,12 @@ class CustomDoubleStreamBlockProcessor: @staticmethod def _double_stream_block_forward( - block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor + block: DoubleStreamBlock, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + pe: torch.Tensor, + attn_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate values. @@ -40,7 +46,7 @@ class CustomDoubleStreamBlockProcessor: k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) - attn = attention(q, k, v, pe=pe) + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks @@ -63,11 +69,15 @@ class CustomDoubleStreamBlockProcessor: vec: torch.Tensor, pe: torch.Tensor, ip_adapter_extensions: list[XLabsIPAdapterExtension], + regional_prompting_extension: RegionalPromptingExtension, ) -> tuple[torch.Tensor, torch.Tensor]: """A custom implementation of DoubleStreamBlock.forward() with additional features: - IP-Adapter support """ - img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe) + attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index) + img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward( + block, img, txt, vec, pe, attn_mask=attn_mask + ) # Apply IP-Adapter conditioning. for ip_adapter_extension in ip_adapter_extensions: @@ -81,3 +91,48 @@ class CustomDoubleStreamBlockProcessor: ) return img, txt + + +class CustomSingleStreamBlockProcessor: + """A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking, + etc.) + """ + + @staticmethod + def _single_stream_block_forward( + block: SingleStreamBlock, + x: torch.Tensor, + vec: torch.Tensor, + pe: torch.Tensor, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """This function is a direct copy of SingleStreamBlock.forward().""" + mod, _ = block.modulation(vec) + x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift + qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1) + + q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads) + q, k = block.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + # compute activation in mlp stream, cat again and run second linear layer + output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2)) + return x + mod.gate * output + + @staticmethod + def custom_single_block_forward( + timestep_index: int, + total_num_timesteps: int, + block_index: int, + block: SingleStreamBlock, + img: torch.Tensor, + vec: torch.Tensor, + pe: torch.Tensor, + regional_prompting_extension: RegionalPromptingExtension, + ) -> torch.Tensor: + """A custom implementation of SingleStreamBlock.forward() with additional features: + - Masking + """ + attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index) + return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index bb0e60409a..66e3984edb 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -7,6 +7,7 @@ from tqdm import tqdm from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension +from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.model import Flux @@ -18,14 +19,8 @@ def denoise( # model input img: torch.Tensor, img_ids: torch.Tensor, - # positive text conditioning - txt: torch.Tensor, - txt_ids: torch.Tensor, - vec: torch.Tensor, - # negative text conditioning - neg_txt: torch.Tensor | None, - neg_txt_ids: torch.Tensor | None, - neg_vec: torch.Tensor | None, + pos_regional_prompting_extension: RegionalPromptingExtension, + neg_regional_prompting_extension: RegionalPromptingExtension | None, # sampling parameters timesteps: list[float], step_callback: Callable[[PipelineIntermediateState], None], @@ -61,9 +56,9 @@ def denoise( total_num_timesteps=total_steps, img=img, img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, + txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, timesteps=t_vec, guidance=guidance_vec, ) @@ -78,9 +73,9 @@ def denoise( pred = model( img=img, img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, + txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, timesteps=t_vec, guidance=guidance_vec, timestep_index=step_index, @@ -88,6 +83,7 @@ def denoise( controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals, controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals, ip_adapter_extensions=pos_ip_adapter_extensions, + regional_prompting_extension=pos_regional_prompting_extension, ) step_cfg_scale = cfg_scale[step_index] @@ -97,15 +93,15 @@ def denoise( # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance # on systems with sufficient VRAM. - if neg_txt is None or neg_txt_ids is None or neg_vec is None: + if neg_regional_prompting_extension is None: raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.") neg_pred = model( img=img, img_ids=img_ids, - txt=neg_txt, - txt_ids=neg_txt_ids, - y=neg_vec, + txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings, timesteps=t_vec, guidance=guidance_vec, timestep_index=step_index, @@ -113,6 +109,7 @@ def denoise( controlnet_double_block_residuals=None, controlnet_single_block_residuals=None, ip_adapter_extensions=neg_ip_adapter_extensions, + regional_prompting_extension=neg_regional_prompting_extension, ) pred = neg_pred + step_cfg_scale * (pred - neg_pred) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py new file mode 100644 index 0000000000..f1086c3286 --- /dev/null +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -0,0 +1,276 @@ +from typing import Optional + +import torch +import torchvision + +from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range +from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.mask import to_standard_float_mask + + +class RegionalPromptingExtension: + """A class for managing regional prompting with FLUX. + + This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences). + """ + + def __init__( + self, + regional_text_conditioning: FluxRegionalTextConditioning, + restricted_attn_mask: torch.Tensor | None = None, + ): + self.regional_text_conditioning = regional_text_conditioning + self.restricted_attn_mask = restricted_attn_mask + + def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None: + order = [self.restricted_attn_mask, None] + return order[block_index % len(order)] + + def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None: + order = [self.restricted_attn_mask, None] + return order[block_index % len(order)] + + @classmethod + def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int): + """Create a RegionalPromptingExtension from a list of text conditionings. + + Args: + text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting. + img_seq_len (int): The image sequence length (i.e. packed_height * packed_width). + """ + regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning) + attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask( + regional_text_conditioning, img_seq_len + ) + return cls( + regional_text_conditioning=regional_text_conditioning, + restricted_attn_mask=attn_mask_with_restricted_img_self_attn, + ) + + # Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy: + # + # @classmethod + # def _prepare_unrestricted_attn_mask( + # cls, + # regional_text_conditioning: FluxRegionalTextConditioning, + # img_seq_len: int, + # ) -> torch.Tensor: + # """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that: + # - img self-attention is not masked. + # - img regions attend to both txt within their own region and to global prompts. + # """ + # device = TorchDevice.choose_torch_device() + + # # Infer txt_seq_len from the t5_embeddings tensor. + # txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1] + + # # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied. + # # Concatenation happens in the following order: [txt_seq, img_seq]. + # # There are 4 portions of the attention mask to consider as we prepare it: + # # 1. txt attends to itself + # # 2. txt attends to corresponding regional img + # # 3. regional img attends to corresponding txt + # # 4. regional img attends to itself + + # # Initialize empty attention mask. + # regional_attention_mask = torch.zeros( + # (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16 + # ) + + # for image_mask, t5_embedding_range in zip( + # regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True + # ): + # # 1. txt attends to itself + # regional_attention_mask[ + # t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end + # ] = 1.0 + + # # 2. txt attends to corresponding regional img + # # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + # fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0 + # regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value + + # # 3. regional img attends to corresponding txt + # # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + # fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0 + # regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value + + # # 4. regional img attends to itself + # # Allow unrestricted img self attention. + # regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0 + + # # Convert attention mask to boolean. + # regional_attention_mask = regional_attention_mask > 0.5 + + # return regional_attention_mask + + @classmethod + def _prepare_restricted_attn_mask( + cls, + regional_text_conditioning: FluxRegionalTextConditioning, + img_seq_len: int, + ) -> torch.Tensor | None: + """Prepare a 'restricted' attention mask. In this context, 'restricted' means that: + - img self-attention is only allowed within regions. + - img regions only attend to txt within their own region, not to global prompts. + """ + # Identify background region. I.e. the region that is not covered by any region masks. + background_region_mask: None | torch.Tensor = None + for image_mask in regional_text_conditioning.image_masks: + if image_mask is not None: + if background_region_mask is None: + background_region_mask = torch.ones_like(image_mask) + background_region_mask *= 1 - image_mask + + if background_region_mask is None: + # There are no region masks, short-circuit and return None. + # TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this would + # is a rare use case and would make the logic here significantly more complicated. + return None + + device = TorchDevice.choose_torch_device() + + # Infer txt_seq_len from the t5_embeddings tensor. + txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1] + + # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied. + # Concatenation happens in the following order: [txt_seq, img_seq]. + # There are 4 portions of the attention mask to consider as we prepare it: + # 1. txt attends to itself + # 2. txt attends to corresponding regional img + # 3. regional img attends to corresponding txt + # 4. regional img attends to itself + + # Initialize empty attention mask. + regional_attention_mask = torch.zeros( + (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16 + ) + + for image_mask, t5_embedding_range in zip( + regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True + ): + # 1. txt attends to itself + regional_attention_mask[ + t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end + ] = 1.0 + + if image_mask is not None: + # 2. txt attends to corresponding regional img + # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = ( + image_mask.view(1, img_seq_len) + ) + + # 3. regional img attends to corresponding txt + # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = ( + image_mask.view(img_seq_len, 1) + ) + + # 4. regional img attends to itself + image_mask = image_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + else: + # We don't allow attention between non-background image regions and global prompts. This helps to ensure + # that regions focus on their local prompts. We do, however, allow attention between background regions + # and global prompts. If we didn't do this, then the background regions would not attend to any txt + # embeddings, which we found experimentally to cause artifacts. + + # 2. global txt attends to background region + # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = ( + background_region_mask.view(1, img_seq_len) + ) + + # 3. background region attends to global txt + # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = ( + background_region_mask.view(img_seq_len, 1) + ) + + # Allow background regions to attend to themselves. + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) + + # Convert attention mask to boolean. + regional_attention_mask = regional_attention_mask > 0.5 + + return regional_attention_mask + + @classmethod + def _concat_regional_text_conditioning( + cls, + text_conditionings: list[FluxTextConditioning], + ) -> FluxRegionalTextConditioning: + """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks).""" + concat_t5_embeddings: list[torch.Tensor] = [] + concat_t5_embedding_ranges: list[Range] = [] + image_masks: list[torch.Tensor | None] = [] + + # Choose global CLIP embedding. + # Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use + # the first prompt's CLIP embedding. + global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings + for text_conditioning in text_conditionings: + if text_conditioning.mask is None: + global_clip_embedding = text_conditioning.clip_embeddings + break + + cur_t5_embedding_len = 0 + for text_conditioning in text_conditionings: + concat_t5_embeddings.append(text_conditioning.t5_embeddings) + + concat_t5_embedding_ranges.append( + Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1]) + ) + + image_masks.append(text_conditioning.mask) + + cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1] + + t5_embeddings = torch.cat(concat_t5_embeddings, dim=1) + + # Initialize the txt_ids tensor. + pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape + t5_txt_ids = torch.zeros( + pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device() + ) + + return FluxRegionalTextConditioning( + t5_embeddings=t5_embeddings, + clip_embeddings=global_clip_embedding, + t5_txt_ids=t5_txt_ids, + image_masks=image_masks, + t5_embedding_ranges=concat_t5_embedding_ranges, + ) + + @staticmethod + def preprocess_regional_prompt_mask( + mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device + ) -> torch.Tensor: + """Preprocess a regional prompt mask to match the target height and width. + If mask is None, returns a mask of all ones with the target height and width. + If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. + + packed_height and packed_width are the target height and width of the mask in the 'packed' latent space. + + Returns: + torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width). + """ + + if mask is None: + return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device) + + mask = to_standard_float_mask(mask, out_dtype=dtype) + + tf = torchvision.transforms.Resize( + (packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + + # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). + mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + resized_mask = tf(mask) + + # Flatten the height and width dimensions into a single image_seq_len dimension. + return resized_mask.flatten(start_dim=2) diff --git a/invokeai/backend/flux/math.py b/invokeai/backend/flux/math.py index 0fac7a1d16..260243a35d 100644 --- a/invokeai/backend/flux/math.py +++ b/invokeai/backend/flux/math.py @@ -5,10 +5,10 @@ from einops import rearrange from torch import Tensor -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor: q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)") return x diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py index 0dadacd8fe..0add6fd4d7 100644 --- a/invokeai/backend/flux/model.py +++ b/invokeai/backend/flux/model.py @@ -5,7 +5,11 @@ from dataclasses import dataclass import torch from torch import Tensor, nn -from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor +from invokeai.backend.flux.custom_block_processor import ( + CustomDoubleStreamBlockProcessor, + CustomSingleStreamBlockProcessor, +) +from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.modules.layers import ( DoubleStreamBlock, @@ -95,6 +99,7 @@ class Flux(nn.Module): controlnet_double_block_residuals: list[Tensor] | None, controlnet_single_block_residuals: list[Tensor] | None, ip_adapter_extensions: list[XLabsIPAdapterExtension], + regional_prompting_extension: RegionalPromptingExtension, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -117,7 +122,6 @@ class Flux(nn.Module): assert len(controlnet_double_block_residuals) == len(self.double_blocks) for block_index, block in enumerate(self.double_blocks): assert isinstance(block, DoubleStreamBlock) - img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward( timestep_index=timestep_index, total_num_timesteps=total_num_timesteps, @@ -128,6 +132,7 @@ class Flux(nn.Module): vec=vec, pe=pe, ip_adapter_extensions=ip_adapter_extensions, + regional_prompting_extension=regional_prompting_extension, ) if controlnet_double_block_residuals is not None: @@ -140,7 +145,17 @@ class Flux(nn.Module): assert len(controlnet_single_block_residuals) == len(self.single_blocks) for block_index, block in enumerate(self.single_blocks): - img = block(img, vec=vec, pe=pe) + assert isinstance(block, SingleStreamBlock) + img = CustomSingleStreamBlockProcessor.custom_single_block_forward( + timestep_index=timestep_index, + total_num_timesteps=total_num_timesteps, + block_index=block_index, + block=block, + img=img, + vec=vec, + pe=pe, + regional_prompting_extension=regional_prompting_extension, + ) if controlnet_single_block_residuals is not None: img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index] diff --git a/invokeai/backend/flux/text_conditioning.py b/invokeai/backend/flux/text_conditioning.py new file mode 100644 index 0000000000..5bc8b4d041 --- /dev/null +++ b/invokeai/backend/flux/text_conditioning.py @@ -0,0 +1,36 @@ +from dataclasses import dataclass + +import torch + +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range + + +@dataclass +class FluxTextConditioning: + t5_embeddings: torch.Tensor + clip_embeddings: torch.Tensor + # If mask is None, the prompt is a global prompt. + mask: torch.Tensor | None + + +@dataclass +class FluxRegionalTextConditioning: + # Concatenated text embeddings. + # Shape: (1, concatenated_txt_seq_len, 4096) + t5_embeddings: torch.Tensor + # Shape: (1, concatenated_txt_seq_len, 3) + t5_txt_ids: torch.Tensor + + # Global CLIP embeddings. + # Shape: (1, 768) + clip_embeddings: torch.Tensor + + # A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a + # global prompt. + # image_masks[i] is the mask for the ith prompt. + # image_masks[i] has shape (1, image_seq_len) and dtype torch.bool. + image_masks: list[torch.Tensor | None] + + # List of ranges that represent the embedding ranges for each mask. + # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i]. + t5_embedding_ranges: list[Range] diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index d8957c0b1e..3b04bcff71 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -176,7 +176,8 @@ "reset": "Reset", "none": "None", "new": "New", - "generating": "Generating" + "generating": "Generating", + "warnings": "Warnings" }, "hrf": { "hrf": "High Resolution Fix", @@ -1040,6 +1041,7 @@ "noNodesInGraph": "No nodes in graph", "systemDisconnected": "System disconnected", "layer": { + "unsupportedModel": "layer not supported for selected base model", "controlAdapterNoModelSelected": "no Control Adapter model selected", "controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model", "t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox width is {{width}}", @@ -1050,7 +1052,10 @@ "ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model", "ipAdapterNoImageSelected": "no IP Adapter image selected", "rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters", - "rgNoRegion": "no region selected" + "rgNegativePromptNotSupported": "negative prompt not supported for selected base model", + "rgReferenceImagesNotSupported": "regional reference images not supported for selected base model", + "rgAutoNegativeNotSupported": "auto-negative not supported for selected base model", + "emptyLayer": "empty layer" } }, "maskBlur": "Mask Blur", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx index ed2bb86a88..c4462c0f4d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx @@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => { justifyContent="flex-start" leftIcon={} onClick={addRegionalGuidance} - isDisabled={isFLUX || isSD3} + isDisabled={isSD3} > {t('controlLayers.regionalGuidance')} diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx index 70623f54b7..40c750bc52 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx @@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => { } onClick={addInpaintMask}> {t('controlLayers.inpaintMask')} - } onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}> + } onClick={addRegionalGuidance} isDisabled={isSD3}> {t('controlLayers.regionalGuidance')} } onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}> diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx index 2fc7483756..3e11bdc4e2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx @@ -1,27 +1,28 @@ -import { IconButton, Tooltip } from '@invoke-ai/ui-library'; +import type { IconButtonProps } from '@invoke-ai/ui-library'; +import { IconButton } from '@invoke-ai/ui-library'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiTrashSimpleFill } from 'react-icons/pi'; +import { PiXBold } from 'react-icons/pi'; -type Props = { +type Props = Omit & { onDelete: () => void; }; -export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) => { +export const RegionalGuidanceDeletePromptButton = memo(({ onDelete, ...rest }: Props) => { const { t } = useTranslation(); return ( - - } - onClick={onDelete} - flexGrow={0} - size="sm" - p={0} - colorScheme="error" - /> - + } + onClick={onDelete} + flexGrow={0} + size="sm" + p={0} + colorScheme="error" + {...rest} + /> ); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx index c98d6bc23d..722720dfbf 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx @@ -1,8 +1,10 @@ import { Button, Flex, Text } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; +import { RegionalGuidanceDeletePromptButton } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy'; +import { rgIPAdapterDeleted } from 'features/controlLayers/store/canvasSlice'; import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd'; import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd'; import { DndDropTarget } from 'features/dnd/DndDropTarget'; @@ -31,6 +33,9 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag const onClickGalleryButton = useCallback(() => { dispatch(activeTabCanvasRightPanelChanged('gallery')); }, [dispatch]); + const onDeleteIPAdapter = useCallback(() => { + dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId })); + }, [dispatch, entityIdentifier, referenceImageId]); const dndTargetData = useMemo( () => @@ -43,6 +48,7 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag return ( + { return ( + {entityIdentifier.type !== 'reference_image' && } diff --git a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx new file mode 100644 index 0000000000..f13de15c19 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx @@ -0,0 +1,101 @@ +import { Flex, IconButton, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled'; +import { selectModel } from 'features/controlLayers/store/paramsSlice'; +import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors'; +import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types'; +import { + getControlLayerWarnings, + getGlobalReferenceImageWarnings, + getInpaintMaskWarnings, + getRasterLayerWarnings, + getRegionalGuidanceWarnings, +} from 'features/controlLayers/store/validators'; +import type { TFunction } from 'i18next'; +import { upperFirst } from 'lodash-es'; +import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiWarningBold } from 'react-icons/pi'; +import type { Equals } from 'tsafe'; +import { assert } from 'tsafe'; + +const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => { + return createSelector(selectCanvasSlice, selectModel, (canvas, model) => { + // This component is used within a so we can safely assume that the entity exists. + // Should never throw. + const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings'); + + let warnings: string[] = []; + + const entityType = entity.type; + + if (entityType === 'control_layer') { + warnings = getControlLayerWarnings(entity, model); + } else if (entityType === 'regional_guidance') { + warnings = getRegionalGuidanceWarnings(entity, model); + } else if (entityType === 'inpaint_mask') { + warnings = getInpaintMaskWarnings(entity, model); + } else if (entityType === 'raster_layer') { + warnings = getRasterLayerWarnings(entity, model); + } else if (entityType === 'reference_image') { + warnings = getGlobalReferenceImageWarnings(entity, model); + } else { + assert>(false, 'Unexpected entity type'); + } + + // Return a stable reference if there are no warnings + if (warnings.length === 0) { + return EMPTY_ARRAY; + } + + return warnings.map((w) => t(w)).map(upperFirst); + }); +}; + +export const CanvasEntityHeaderWarnings = memo(() => { + const entityIdentifier = useEntityIdentifierContext(); + const { t } = useTranslation(); + const isEnabled = useEntityIsEnabled(entityIdentifier); + const selectWarnings = useMemo(() => buildSelectWarnings(entityIdentifier, t), [entityIdentifier, t]); + const warnings = useAppSelector(selectWarnings); + + if (warnings.length === 0) { + return null; + } + + return ( + // Using IconButton here bc it matches the styling of the actual buttons in the header without any fanagling, but + // it's not a button + } + icon={} + colorScheme="warning" + isDisabled={!isEnabled} + /> + ); +}); + +CanvasEntityHeaderWarnings.displayName = 'CanvasEntityHeaderWarnings'; + +const TooltipContent = memo((props: { warnings: string[] }) => { + const { t } = useTranslation(); + return ( + + {t('common.warnings')}: + + {props.warnings.map((warning, index) => ( + {warning} + ))} + + + ); +}); +TooltipContent.displayName = 'TooltipContent'; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts new file mode 100644 index 0000000000..0b8a34c3bd --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -0,0 +1,133 @@ +import type { + CanvasControlLayerState, + CanvasInpaintMaskState, + CanvasRasterLayerState, + CanvasReferenceImageState, + CanvasRegionalGuidanceState, +} from 'features/controlLayers/store/types'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; + +export const getRegionalGuidanceWarnings = ( + entity: CanvasRegionalGuidanceState, + model: ParameterModel | null +): string[] => { + const warnings: string[] = []; + + if (entity.objects.length === 0) { + // Layer is in empty state - skip other checks + warnings.push('parameters.invoke.layer.emptyLayer'); + } else { + if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) { + // Must have at least 1 prompt or IP Adapter + warnings.push('parameters.invoke.layer.rgNoPromptsOrIPAdapters'); + } + + if (model) { + if (model.base === 'sd-3' || model.base === 'sd-2') { + // Unsupported model architecture + warnings.push('parameters.invoke.layer.unsupportedModel'); + } else if (model.base === 'flux') { + // Some features are not supported for flux models + if (entity.negativePrompt !== null) { + warnings.push('parameters.invoke.layer.rgNegativePromptNotSupported'); + } + if (entity.referenceImages.length > 0) { + warnings.push('parameters.invoke.layer.rgReferenceImagesNotSupported'); + } + if (entity.autoNegative) { + warnings.push('parameters.invoke.layer.rgAutoNegativeNotSupported'); + } + } else { + entity.referenceImages.forEach(({ ipAdapter }) => { + if (!ipAdapter.model) { + // No model selected + warnings.push('parameters.invoke.layer.ipAdapterNoModelSelected'); + } else if (ipAdapter.model.base !== model.base) { + // Supported model architecture but doesn't match + warnings.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'); + } + + if (!ipAdapter.image) { + // No image selected + warnings.push('parameters.invoke.layer.ipAdapterNoImageSelected'); + } + }); + } + } + } + + return warnings; +}; + +export const getGlobalReferenceImageWarnings = ( + entity: CanvasReferenceImageState, + model: ParameterModel | null +): string[] => { + const warnings: string[] = []; + + if (!entity.ipAdapter.model) { + // No model selected + warnings.push('parameters.invoke.layer.ipAdapterNoModelSelected'); + } else if (model) { + if (model.base === 'sd-3' || model.base === 'sd-2') { + // Unsupported model architecture + warnings.push('parameters.invoke.layer.unsupportedModel'); + } else if (entity.ipAdapter.model.base !== model.base) { + // Supported model architecture but doesn't match + warnings.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'); + } + } + + if (!entity.ipAdapter.image) { + // No image selected + warnings.push('parameters.invoke.layer.ipAdapterNoImageSelected'); + } + + return warnings; +}; + +export const getControlLayerWarnings = (entity: CanvasControlLayerState, model: ParameterModel | null): string[] => { + const warnings: string[] = []; + + if (entity.objects.length === 0) { + // Layer is in empty state - skip other checks + warnings.push('parameters.invoke.layer.emptyLayer'); + } else { + if (!entity.controlAdapter.model) { + // No model selected + warnings.push('parameters.invoke.layer.controlAdapterNoModelSelected'); + } else if (model) { + if (model.base === 'sd-3' || model.base === 'sd-2') { + // Unsupported model architecture + warnings.push('parameters.invoke.layer.unsupportedModel'); + } else if (entity.controlAdapter.model.base !== model.base) { + // Supported model architecture but doesn't match + warnings.push('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'); + } + } + } + + return warnings; +}; + +export const getRasterLayerWarnings = (entity: CanvasRasterLayerState, _model: ParameterModel | null): string[] => { + const warnings: string[] = []; + + if (entity.objects.length === 0) { + // Layer is in empty state - skip other checks + warnings.push('parameters.invoke.layer.emptyLayer'); + } + + return warnings; +}; + +export const getInpaintMaskWarnings = (entity: CanvasInpaintMaskState, _model: ParameterModel | null): string[] => { + const warnings: string[] = []; + + if (entity.objects.length === 0) { + // Layer is in empty state - skip other checks + warnings.push('parameters.invoke.layer.emptyLayer'); + } + + return warnings; +}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index d012ca2853..a41dc06550 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -1,35 +1,41 @@ import { logger } from 'app/logging/logger'; import { withResultAsync } from 'common/util/result'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; -import type { - CanvasControlLayerState, - ControlNetConfig, - Rect, - T2IAdapterConfig, -} from 'features/controlLayers/store/types'; +import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types'; +import { getControlLayerWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; import { serializeError } from 'serialize-error'; -import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; +import type { ImageDTO, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; const log = logger('system'); +type AddControlNetsArg = { + manager: CanvasManager; + entities: CanvasControlLayerState[]; + g: Graph; + rect: Rect; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + type AddControlNetsResult = { addedControlNets: number; }; -export const addControlNets = async ( - manager: CanvasManager, - layers: CanvasControlLayerState[], - g: Graph, - rect: Rect, - collector: Invocation<'collect'>, - base: BaseModelType -): Promise => { - const validControlLayers = layers - .filter((layer) => layer.isEnabled) - .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)) - .filter((layer) => layer.controlAdapter.type === 'controlnet'); +export const addControlNets = async ({ + manager, + entities, + g, + rect, + collector, + model, +}: AddControlNetsArg): Promise => { + const validControlLayers = entities + .filter((entity) => entity.isEnabled) + .filter((entity) => entity.controlAdapter.type === 'controlnet') + .filter((entity) => getControlLayerWarnings(entity, model).length === 0); const result: AddControlNetsResult = { addedControlNets: 0, @@ -54,22 +60,31 @@ export const addControlNets = async ( return result; }; +type AddT2IAdaptersArg = { + manager: CanvasManager; + entities: CanvasControlLayerState[]; + g: Graph; + rect: Rect; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + type AddT2IAdaptersResult = { addedT2IAdapters: number; }; -export const addT2IAdapters = async ( - manager: CanvasManager, - layers: CanvasControlLayerState[], - g: Graph, - rect: Rect, - collector: Invocation<'collect'>, - base: BaseModelType -): Promise => { - const validControlLayers = layers - .filter((layer) => layer.isEnabled) - .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)) - .filter((layer) => layer.controlAdapter.type === 't2i_adapter'); +export const addT2IAdapters = async ({ + manager, + entities, + g, + rect, + collector, + model, +}: AddT2IAdaptersArg): Promise => { + const validControlLayers = entities + .filter((entity) => entity.isEnabled) + .filter((entity) => entity.controlAdapter.type === 't2i_adapter') + .filter((entity) => getControlLayerWarnings(entity, model).length === 0); const result: AddT2IAdaptersResult = { addedT2IAdapters: 0, @@ -145,11 +160,3 @@ const addT2IAdapterToGraph = ( g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item'); }; - -const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => { - // Must be have a model - const hasModel = Boolean(controlAdapter.model); - // Model must match the current base model - const modelMatchesBase = controlAdapter.model?.base === base; - return hasModel && modelMatchesBase; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts index 81b98f3ef5..0a3a43a018 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts @@ -1,19 +1,23 @@ import type { CanvasReferenceImageState } from 'features/controlLayers/store/types'; +import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; -import type { BaseModelType, Invocation } from 'services/api/types'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; +import type { Invocation } from 'services/api/types'; import { assert } from 'tsafe'; type AddIPAdaptersResult = { addedIPAdapters: number; }; -export const addIPAdapters = ( - ipAdapters: CanvasReferenceImageState[], - g: Graph, - collector: Invocation<'collect'>, - base: BaseModelType -): AddIPAdaptersResult => { - const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity, base)); +type AddIPAdaptersArg = { + entities: CanvasReferenceImageState[]; + g: Graph; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + +export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => { + const validIPAdapters = entities.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0); const result: AddIPAdaptersResult = { addedIPAdapters: 0, @@ -76,11 +80,3 @@ const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: In g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item'); }; - -const isValidIPAdapter = ({ isEnabled, ipAdapter }: CanvasReferenceImageState, base: BaseModelType): boolean => { - // Must be have a model that matches the current base and must have a control image - const hasModel = Boolean(ipAdapter.model); - const modelMatchesBase = ipAdapter.model?.base === base; - const hasImage = Boolean(ipAdapter.image); - return isEnabled && hasModel && modelMatchesBase && hasImage; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index dcce2046da..a1921200d5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone'; import { withResultAsync } from 'common/util/result'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; -import type { - CanvasRegionalGuidanceState, - IPAdapterConfig, - Rect, - RegionalGuidanceReferenceImageState, -} from 'features/controlLayers/store/types'; +import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types'; +import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; import { serializeError } from 'serialize-error'; -import type { BaseModelType, Invocation } from 'services/api/types'; +import type { Invocation } from 'services/api/types'; import { assert } from 'tsafe'; const log = logger('system'); @@ -23,19 +20,26 @@ type AddedRegionResult = { addedIPAdapters: number; }; -const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => { - const isEnabled = rg.isEnabled; - const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt); - const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0; - return isEnabled && (hasTextPrompt || hasIPAdapter); +type AddRegionsArg = { + manager: CanvasManager; + regions: CanvasRegionalGuidanceState[]; + g: Graph; + bbox: Rect; + model: ParameterModel; + posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null; + posCondCollect: Invocation<'collect'>; + negCondCollect: Invocation<'collect'> | null; + ipAdapterCollect: Invocation<'collect'>; }; /** * Adds regional guidance to the graph + * @param manager The canvas manager * @param regions Array of regions to add * @param g The graph to add the layers to - * @param base The base model type - * @param denoise The main denoise node + * @param bbox The bounding box + * @param model The main model * @param posCond The positive conditioning node * @param negCond The negative conditioning node * @param posCondCollect The positive conditioning collector @@ -44,22 +48,28 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => * @returns A promise that resolves to the regions that were successfully added to the graph */ -export const addRegions = async ( - manager: CanvasManager, - regions: CanvasRegionalGuidanceState[], - g: Graph, - bbox: Rect, - base: BaseModelType, - denoise: Invocation<'denoise_latents'>, - posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, - negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, - posCondCollect: Invocation<'collect'>, - negCondCollect: Invocation<'collect'>, - ipAdapterCollect: Invocation<'collect'> -): Promise => { - const isSDXL = base === 'sdxl'; +export const addRegions = async ({ + manager, + regions, + g, + bbox, + model, + posCond, + negCond, + posCondCollect, + negCondCollect, + ipAdapterCollect, +}: AddRegionsArg): Promise => { + const isSDXL = model.base === 'sdxl'; + const isFLUX = model.base === 'flux'; + + const validRegions = regions.filter((rg) => { + if (!rg.isEnabled) { + return false; + } + return getRegionalGuidanceWarnings(rg, model).length === 0; + }); - const validRegions = regions.filter((rg) => isValidRegion(rg, base)); const results: AddedRegionResult[] = []; for (const region of validRegions) { @@ -94,20 +104,27 @@ export const addRegions = async ( if (region.positivePrompt) { // The main positive conditioning node result.addedPositivePrompt = true; - const regionalPosCond = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_positive_cond'), - prompt: region.positivePrompt, - style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_positive_cond'), - prompt: region.positivePrompt, - } - ); + let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalPosCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? + }); + } else if (isFLUX) { + regionalPosCond = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + }); + } else { + regionalPosCond = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + }); + } // Connect the mask to the conditioning g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask'); // Connect the conditioning to the collector @@ -115,38 +132,55 @@ export const addRegions = async ( // Copy the connections to the "global" positive conditioning node to the regional cond if (posCond.type === 'compel') { for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } + } else if (posCond.type === 'sdxl_compel_prompt') { + for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } + } else if (posCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalPosCond.id; g.addEdgeFromObj(clone); } } else { - for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node - const clone = deepClone(edge); - clone.destination.node_id = regionalPosCond.id; - g.addEdgeFromObj(clone); - } + assert(false, 'Unsupported positive conditioning node type.'); } } if (region.negativePrompt) { - result.addedNegativePrompt = true; + assert(negCond, 'Negative conditioning node is required if there is a negative prompt'); + assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt'); + // The main negative conditioning node - const regionalNegCond = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_negative_cond'), - prompt: region.negativePrompt, - style: region.negativePrompt, - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_negative_cond'), - prompt: region.negativePrompt, - } - ); + result.addedNegativePrompt = true; + let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalNegCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + style: region.negativePrompt, + }); + } else if (isFLUX) { + regionalNegCond = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + }); + } else { + regionalNegCond = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + }); + } + // Connect the mask to the conditioning g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask'); // Connect the conditioning to the collector @@ -158,17 +192,27 @@ export const addRegions = async ( clone.destination.node_id = regionalNegCond.id; g.addEdgeFromObj(clone); } - } else { + } else if (negCond.type === 'sdxl_compel_prompt') { for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalNegCond.id; g.addEdgeFromObj(clone); } + } else if (negCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalNegCond.id; + g.addEdgeFromObj(clone); + } + } else { + assert(false, 'Unsupported negative conditioning node type.'); } } // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node if (region.autoNegative && region.positivePrompt) { + assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting'); + result.addedAutoNegativePositivePrompt = true; // We re-use the mask image, but invert it when converting to tensor const invertTensorMask = g.addNode({ @@ -178,20 +222,27 @@ export const addRegions = async ( // Connect the OG mask image to the inverted mask-to-tensor node g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask'); // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt - const regionalPosCondInverted = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_positive_cond_inverted'), - prompt: region.positivePrompt, - style: region.positivePrompt, - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_positive_cond_inverted'), - prompt: region.positivePrompt, - } - ); + let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalPosCondInverted = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + style: region.positivePrompt, + }); + } else if (isFLUX) { + regionalPosCondInverted = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + }); + } else { + regionalPosCondInverted = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + }); + } // Connect the inverted mask to the conditioning g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask'); // Connect the conditioning to the negative collector @@ -203,20 +254,26 @@ export const addRegions = async ( clone.destination.node_id = regionalPosCondInverted.id; g.addEdgeFromObj(clone); } - } else { + } else if (posCond.type === 'sdxl_compel_prompt') { for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalPosCondInverted.id; g.addEdgeFromObj(clone); } + } else if (posCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCondInverted.id; + g.addEdgeFromObj(clone); + } + } else { + assert(false, 'Unsupported positive conditioning node type.'); } } - const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) => - isValidIPAdapter(ipAdapter, base) - ); + for (const { id, ipAdapter } of region.referenceImages) { + assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.'); - for (const { id, ipAdapter } of validRGIPAdapters) { result.addedIPAdapters++; const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter; assert(model, 'IP Adapter model is required'); @@ -248,11 +305,3 @@ export const addRegions = async ( return results; }; - -const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => { - // Must be have a model that matches the current base and must have a control image - const hasModel = Boolean(ipAdapter.model); - const modelMatchesBase = ipAdapter.model?.base === base; - const hasImage = Boolean(ipAdapter.image); - return hasModel && modelMatchesBase && hasImage; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index d893760f3c..885954e995 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint'; +import { addRegions } from 'features/nodes/util/graph/generation/addRegions'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; @@ -79,7 +80,10 @@ export const buildFLUXGraph = async ( id: getPrefixedId('flux_text_encoder'), prompt: positivePrompt, }); - + const posCondCollect = g.addNode({ + type: 'collect', + id: getPrefixedId('pos_cond_collect'), + }); const denoise = g.addNode({ type: 'flux_denoise', id: getPrefixedId('flux_denoise'), @@ -104,13 +108,12 @@ export const buildFLUXGraph = async ( g.addEdge(modelLoader, 'clip', posCond, 'clip'); g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder'); g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len'); + g.addEdge(posCond, 'conditioning', posCondCollect, 'item'); + g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning'); + g.addEdge(denoise, 'latents', l2i, 'latents'); addFLUXLoRAs(state, g, denoise, modelLoader, posCond); - g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning'); - - g.addEdge(denoise, 'latents', l2i, 'latents'); - const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); assert(modelConfig.base === 'flux'); @@ -196,31 +199,50 @@ export const buildFLUXGraph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { g.deleteNode(controlNetCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, + g, + collector: ipAdapterCollect, + model: modelConfig, + }); - const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters; + const regionsResult = await addRegions({ + manager, + regions: canvas.regionalGuidance.entities, + g, + bbox: canvas.bbox.rect, + model: modelConfig, + posCond, + negCond: null, + posCondCollect, + negCondCollect: null, + ipAdapterCollect, + }); + + const totalIPAdaptersAdded = + ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 4008fd05d6..7522227007 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -227,14 +227,14 @@ export const buildSD1Graph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { @@ -245,46 +245,50 @@ export const buildSD1Graph = async ( type: 'collect', id: getPrefixedId('t2i_adapter_collector'), }); - const t2iAdapterResult = await addT2IAdapters( + const t2iAdapterResult = await addT2IAdapters({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - t2iAdapterCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: t2iAdapterCollector, + model: modelConfig, + }); if (t2iAdapterResult.addedT2IAdapters > 0) { g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter'); } else { g.deleteNode(t2iAdapterCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); - - const regionsResult = await addRegions( - manager, - canvas.regionalGuidance.entities, + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, g, - canvas.bbox.rect, - modelConfig.base, - denoise, + collector: ipAdapterCollect, + model: modelConfig, + }); + + const regionsResult = await addRegions({ + manager, + regions: canvas.regionalGuidance.entities, + g, + bbox: canvas.bbox.rect, + model: modelConfig, posCond, negCond, posCondCollect, negCondCollect, - ipAdapterCollector - ); + ipAdapterCollect, + }); const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 37ab697522..9357a291b4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -232,14 +232,14 @@ export const buildSDXLGraph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { @@ -250,46 +250,50 @@ export const buildSDXLGraph = async ( type: 'collect', id: getPrefixedId('t2i_adapter_collector'), }); - const t2iAdapterResult = await addT2IAdapters( + const t2iAdapterResult = await addT2IAdapters({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - t2iAdapterCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: t2iAdapterCollector, + model: modelConfig, + }); if (t2iAdapterResult.addedT2IAdapters > 0) { g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter'); } else { g.deleteNode(t2iAdapterCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); - - const regionsResult = await addRegions( - manager, - canvas.regionalGuidance.entities, + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, g, - canvas.bbox.rect, - modelConfig.base, - denoise, + collector: ipAdapterCollect, + model: modelConfig, + }); + + const regionsResult = await addRegions({ + manager, + regions: canvas.regionalGuidance.entities, + g, + bbox: canvas.bbox.rect, + model: modelConfig, posCond, negCond, posCondCollect, negCondCollect, - ipAdapterCollector - ); + ipAdapterCollect, + }); const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 0d14bba73d..0af16bdf78 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -4,6 +4,13 @@ import type { ParamsState } from 'features/controlLayers/store/paramsSlice'; import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; import type { CanvasState } from 'features/controlLayers/store/types'; +import { + getControlLayerWarnings, + getGlobalReferenceImageWarnings, + getInpaintMaskWarnings, + getRasterLayerWarnings, + getRegionalGuidanceWarnings, +} from 'features/controlLayers/store/validators'; import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; @@ -278,17 +285,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; - // Must have model - if (!controlLayer.controlAdapter.model) { - problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected')); - } - // Model base must match - if (controlLayer.controlAdapter.model?.base !== model?.base) { - problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel')); - } + const problems = getControlLayerWarnings(controlLayer, model); + if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); @@ -300,23 +300,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; - - // Must have model - if (!entity.ipAdapter.model) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected')); - } - // Model base must match - if (entity.ipAdapter.model?.base !== model?.base) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); - } - // Must have an image - if (!entity.ipAdapter.image) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); - } + const problems = getGlobalReferenceImageWarnings(entity, model); if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); @@ -328,32 +315,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; - // Must have a region - if (entity.objects.length === 0) { - problems.push(i18n.t('parameters.invoke.layer.rgNoRegion')); - } - // Must have at least 1 prompt or IP Adapter - if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) { - problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters')); - } - entity.referenceImages.forEach(({ ipAdapter }) => { - // Must have model - if (!ipAdapter.model) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected')); - } - // Model base must match - if (ipAdapter.model?.base !== model?.base) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); - } - // Must have an image - if (!ipAdapter.image) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); - } - }); + const problems = getRegionalGuidanceWarnings(entity, model); if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); @@ -365,10 +330,25 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; + const problems = getRasterLayerWarnings(entity, model); if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); + reasons.push({ prefix, content }); + } + }); + + canvas.inpaintMasks.entities + .filter((entity) => entity.isEnabled) + .forEach((entity, i) => { + const layerLiteral = i18n.t('controlLayers.layer_one'); + const layerNumber = i + 1; + const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); + const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; + const problems = getInpaintMaskWarnings(entity, model); + + if (problems.length) { + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 90adf7517e..8878a4bcfa 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -6564,6 +6564,11 @@ export type components = { * @description The name of conditioning tensor */ conditioning_name: string; + /** + * @description The mask associated with this conditioning tensor. Excluded regions should be set to False, included regions should be set to True. + * @default null + */ + mask?: components["schemas"]["TensorField"] | null; }; /** * FluxConditioningOutput @@ -6771,15 +6776,17 @@ export type components = { */ transformer?: components["schemas"]["TransformerField"]; /** + * Positive Text Conditioning * @description Positive conditioning tensor * @default null */ - positive_text_conditioning?: components["schemas"]["FluxConditioningField"]; + positive_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][]; /** + * Negative Text Conditioning * @description Negative conditioning tensor. Can be None if cfg_scale is 1.0. * @default null */ - negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | null; + negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][] | null; /** * CFG Scale * @description Classifier-Free Guidance scale @@ -7133,6 +7140,11 @@ export type components = { * @default null */ prompt?: string; + /** + * @description A mask defining the region that this conditioning prompt applies to. + * @default null + */ + mask?: components["schemas"]["TensorField"] | null; /** * type * @default flux_text_encoder