From 85c616fa34e21144dcb2d700f24b5f36a84e970f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 20 Nov 2024 18:51:43 +0000 Subject: [PATCH 01/28] WIP - Pass prompt masks to FLUX model during denoising. --- invokeai/app/invocations/fields.py | 5 + invokeai/app/invocations/flux_denoise.py | 180 +++++++++++++----- invokeai/app/invocations/flux_text_encoder.py | 11 +- invokeai/backend/flux/denoise.py | 12 +- invokeai/backend/flux/text_conditioning.py | 32 ++++ 5 files changed, 186 insertions(+), 54 deletions(-) create mode 100644 invokeai/backend/flux/text_conditioning.py diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py index 5e76931933..16bfcaeed7 100644 --- a/invokeai/app/invocations/fields.py +++ b/invokeai/app/invocations/fields.py @@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel): """A conditioning tensor primitive value""" conditioning_name: str = Field(description="The name of conditioning tensor") + mask: Optional[TensorField] = Field( + default=None, + description="The mask associated with this conditioning tensor. Excluded regions should be set to False, " + "included regions should be set to True.", + ) class SD3ConditioningField(BaseModel): diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 9e197626b5..9faa1207d2 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -4,6 +4,7 @@ from typing import Callable, Iterator, Optional, Tuple import numpy as np import numpy.typing as npt import torch +import torchvision import torchvision.transforms as tv_transforms from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection @@ -42,13 +43,15 @@ from invokeai.backend.flux.sampling_utils import ( pack, unpack, ) +from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX from invokeai.backend.lora.lora_model_raw import LoRAModelRaw from invokeai.backend.lora.lora_patcher import LoRAPatcher from invokeai.backend.model_manager.config import ModelFormat from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo, Range from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.mask import to_standard_float_mask @invocation( @@ -87,10 +90,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): input=Input.Connection, title="Transformer", ) - positive_text_conditioning: FluxConditioningField = InputField( + positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField( description=FieldDescriptions.positive_cond, input=Input.Connection ) - negative_text_conditioning: FluxConditioningField | None = InputField( + negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField( default=None, description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.", input=Input.Connection, @@ -139,18 +142,112 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): name = context.tensors.save(tensor=latents) return LatentsOutput.build(latents_name=name, latents=latents, seed=None) + @staticmethod + def _preprocess_regional_prompt_mask( + mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype + ) -> torch.Tensor: + """Preprocess a regional prompt mask to match the target height and width. + If mask is None, returns a mask of all ones with the target height and width. + If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. + + Returns: + torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). + """ + + if mask is None: + return torch.ones((1, 1, target_height, target_width), dtype=dtype) + + mask = to_standard_float_mask(mask, out_dtype=dtype) + + tf = torchvision.transforms.Resize( + (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + + # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). + mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + resized_mask = tf(mask) + return resized_mask + def _load_text_conditioning( - self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Load the conditioning data. - cond_data = context.conditioning.load(conditioning_name) - assert len(cond_data.conditionings) == 1 - flux_conditioning = cond_data.conditionings[0] - assert isinstance(flux_conditioning, FLUXConditioningInfo) - flux_conditioning = flux_conditioning.to(dtype=dtype) - t5_embeddings = flux_conditioning.t5_embeds - clip_embeddings = flux_conditioning.clip_embeds - return t5_embeddings, clip_embeddings + self, + context: InvocationContext, + cond_field: FluxConditioningField | list[FluxConditioningField], + latent_height: int, + latent_width: int, + dtype: torch.dtype, + ) -> list[FluxTextConditioning]: + """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields.""" + # Normalize to a list of FluxConditioningFields. + cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field + + text_conditionings: list[FluxTextConditioning] = [] + for cond_field in cond_list: + # Load the text embeddings. + cond_data = context.conditioning.load(cond_field.conditioning_name) + assert len(cond_data.conditionings) == 1 + flux_conditioning = cond_data.conditionings[0] + assert isinstance(flux_conditioning, FLUXConditioningInfo) + flux_conditioning = flux_conditioning.to(dtype=dtype) + t5_embeddings = flux_conditioning.t5_embeds + clip_embeddings = flux_conditioning.clip_embeds + + # Load the mask, if provided. + mask: Optional[torch.Tensor] = None + if cond_field.mask is not None: + mask = context.tensors.load(cond_field.mask.tensor_name) + mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype) + + text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask)) + + return text_conditionings + + def _concat_regional_text_conditioning( + self, text_conditionings: list[FluxTextConditioning] + ) -> FluxRegionalTextConditioning: + """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks).""" + concat_t5_embeddings: list[torch.Tensor] = [] + concat_clip_embeddings: list[torch.Tensor] = [] + concat_image_masks: list[torch.Tensor] = [] + concat_t5_embedding_ranges: list[Range] = [] + concat_clip_embedding_ranges: list[Range] = [] + + cur_t5_embedding_len = 0 + cur_clip_embedding_len = 0 + for text_conditioning in text_conditionings: + concat_t5_embeddings.append(text_conditioning.t5_embeddings) + concat_clip_embeddings.append(text_conditioning.clip_embeddings) + + concat_t5_embedding_ranges.append( + Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1]) + ) + concat_clip_embedding_ranges.append( + Range( + start=cur_clip_embedding_len, + end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1], + ) + ) + + concat_image_masks.append(text_conditioning.mask) + + cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1] + cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1] + + t5_embeddings = torch.cat(concat_t5_embeddings, dim=1) + + # Initialize the txt_ids tensor. + pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape + t5_txt_ids = torch.zeros( + pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device() + ) + + return FluxRegionalTextConditioning( + t5_embeddings=t5_embeddings, + clip_embeddings=torch.cat(concat_clip_embeddings, dim=1), + t5_txt_ids=t5_txt_ids, + image_masks=torch.cat(concat_image_masks, dim=1), + t5_embedding_ranges=concat_t5_embedding_ranges, + clip_embedding_ranges=concat_clip_embedding_ranges, + ) def _run_diffusion( self, @@ -158,17 +255,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): ): inference_dtype = torch.bfloat16 - # Load the conditioning data. - pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning( - context, self.positive_text_conditioning.conditioning_name, inference_dtype - ) - neg_t5_embeddings: torch.Tensor | None = None - neg_clip_embeddings: torch.Tensor | None = None - if self.negative_text_conditioning is not None: - neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning( - context, self.negative_text_conditioning.conditioning_name, inference_dtype - ) - # Load the input latents, if provided. init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None if init_latents is not None: @@ -183,6 +269,30 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): dtype=inference_dtype, seed=self.seed, ) + b, _c, latent_h, latent_w = noise.shape + + # Load the conditioning data. + pos_text_conditionings = self._load_text_conditioning( + context=context, + cond_field=self.positive_text_conditioning, + latent_height=latent_h, + latent_width=latent_w, + dtype=inference_dtype, + ) + neg_text_conditionings: list[FluxTextConditioning] | None = None + if self.negative_text_conditioning is not None: + neg_text_conditionings = self._load_text_conditioning( + context=context, + cond_field=self.negative_text_conditioning, + latent_height=latent_h, + latent_width=latent_w, + dtype=inference_dtype, + ) + + pos_regional_text_conditioning = self._concat_regional_text_conditioning(pos_text_conditionings) + neg_regional_text_conditioning = ( + self._concat_regional_text_conditioning(neg_text_conditionings) if neg_text_conditionings else None + ) transformer_info = context.models.load(self.transformer.transformer) is_schnell = "schnell" in transformer_info.config.config_path @@ -228,20 +338,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): inpaint_mask = self._prep_inpaint_mask(context, x) - b, _c, latent_h, latent_w = x.shape img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype) - pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape - pos_txt_ids = torch.zeros( - pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device() - ) - neg_txt_ids: torch.Tensor | None = None - if neg_t5_embeddings is not None: - neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape - neg_txt_ids = torch.zeros( - neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device() - ) - # Pack all latent tensors. init_latents = pack(init_latents) if init_latents is not None else None inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None @@ -338,12 +436,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): model=transformer, img=x, img_ids=img_ids, - txt=pos_t5_embeddings, - txt_ids=pos_txt_ids, - vec=pos_clip_embeddings, - neg_txt=neg_t5_embeddings, - neg_txt_ids=neg_txt_ids, - neg_vec=neg_clip_embeddings, + pos_text_conditioning=pos_regional_text_conditioning, + neg_text_conditioning=neg_regional_text_conditioning, timesteps=timesteps, step_callback=self._build_step_callback(context), guidance=self.guidance, diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index af250f0f3b..cc9c68eca4 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -1,11 +1,11 @@ from contextlib import ExitStack -from typing import Iterator, Literal, Tuple +from typing import Iterator, Literal, Optional, Tuple import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation -from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField +from invokeai.app.invocations.fields import FieldDescriptions, FluxConditioningField, Input, InputField, TensorField from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext @@ -42,6 +42,9 @@ class FluxTextEncoderInvocation(BaseInvocation): description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models." ) prompt: str = InputField(description="Text prompt to encode.") + mask: Optional[TensorField] = InputField( + default=None, description="A mask defining the region that this conditioning prompt applies to." + ) @torch.no_grad() def invoke(self, context: InvocationContext) -> FluxConditioningOutput: @@ -54,7 +57,9 @@ class FluxTextEncoderInvocation(BaseInvocation): ) conditioning_name = context.conditioning.save(conditioning_data) - return FluxConditioningOutput.build(conditioning_name) + return FluxConditioningOutput( + conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask) + ) def _t5_encode(self, context: InvocationContext) -> torch.Tensor: t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index bb0e60409a..c1cb3bbb71 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -10,6 +10,7 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.model import Flux +from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -18,14 +19,8 @@ def denoise( # model input img: torch.Tensor, img_ids: torch.Tensor, - # positive text conditioning - txt: torch.Tensor, - txt_ids: torch.Tensor, - vec: torch.Tensor, - # negative text conditioning - neg_txt: torch.Tensor | None, - neg_txt_ids: torch.Tensor | None, - neg_vec: torch.Tensor | None, + pos_text_conditioning: FluxRegionalTextConditioning, + neg_text_conditioning: FluxRegionalTextConditioning | None, # sampling parameters timesteps: list[float], step_callback: Callable[[PipelineIntermediateState], None], @@ -55,6 +50,7 @@ def denoise( # Run ControlNet models. controlnet_residuals: list[ControlNetFluxOutput] = [] for controlnet_extension in controlnet_extensions: + # FIX(ryand): Revive ControlNet functionality. controlnet_residuals.append( controlnet_extension.run_controlnet( timestep_index=step_index, diff --git a/invokeai/backend/flux/text_conditioning.py b/invokeai/backend/flux/text_conditioning.py new file mode 100644 index 0000000000..82d37688d1 --- /dev/null +++ b/invokeai/backend/flux/text_conditioning.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass + +import torch + +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range + + +@dataclass +class FluxTextConditioning: + t5_embeddings: torch.Tensor + clip_embeddings: torch.Tensor + mask: torch.Tensor + + +@dataclass +class FluxRegionalTextConditioning: + # Concatenated text embeddings. + t5_embeddings: torch.Tensor + clip_embeddings: torch.Tensor + + t5_txt_ids: torch.Tensor + + # A binary mask indicating the regions of the image that the prompt should be applied to. + # Shape: (1, num_prompts, height, width) + # Dtype: torch.bool + image_masks: torch.Tensor + + # List of ranges that represent the embedding ranges for each mask. + # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i]. + # clip_embedding_ranges[i] contains the range of the clip embeddings that correspond to image_masks[i]. + t5_embedding_ranges: list[Range] + clip_embedding_ranges: list[Range] From fda7aaa7ca74925736bdfe339d08f3ce7ebf29a6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 20 Nov 2024 19:48:04 +0000 Subject: [PATCH 02/28] Pass RegionalPromptingExtension down to the CustomDoubleStreamBlockProcessor in FLUX. --- invokeai/app/invocations/flux_denoise.py | 160 +++++------------- .../backend/flux/custom_block_processor.py | 2 + invokeai/backend/flux/denoise.py | 29 ++-- .../regional_prompting_extension.py | 96 +++++++++++ invokeai/backend/flux/model.py | 3 + 5 files changed, 159 insertions(+), 131 deletions(-) create mode 100644 invokeai/backend/flux/extensions/regional_prompting_extension.py diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 9faa1207d2..67bcfc785f 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -4,7 +4,6 @@ from typing import Callable, Iterator, Optional, Tuple import numpy as np import numpy.typing as npt import torch -import torchvision import torchvision.transforms as tv_transforms from torchvision.transforms.functional import resize as tv_resize from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection @@ -31,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN from invokeai.backend.flux.denoise import denoise from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension +from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux @@ -43,15 +43,14 @@ from invokeai.backend.flux.sampling_utils import ( pack, unpack, ) -from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning +from invokeai.backend.flux.text_conditioning import FluxTextConditioning from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX from invokeai.backend.lora.lora_model_raw import LoRAModelRaw from invokeai.backend.lora.lora_patcher import LoRAPatcher from invokeai.backend.model_manager.config import ModelFormat from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState -from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo, Range +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo from invokeai.backend.util.devices import TorchDevice -from invokeai.backend.util.mask import to_standard_float_mask @invocation( @@ -142,113 +141,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): name = context.tensors.save(tensor=latents) return LatentsOutput.build(latents_name=name, latents=latents, seed=None) - @staticmethod - def _preprocess_regional_prompt_mask( - mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype - ) -> torch.Tensor: - """Preprocess a regional prompt mask to match the target height and width. - If mask is None, returns a mask of all ones with the target height and width. - If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. - - Returns: - torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). - """ - - if mask is None: - return torch.ones((1, 1, target_height, target_width), dtype=dtype) - - mask = to_standard_float_mask(mask, out_dtype=dtype) - - tf = torchvision.transforms.Resize( - (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST - ) - - # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). - mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) - resized_mask = tf(mask) - return resized_mask - - def _load_text_conditioning( - self, - context: InvocationContext, - cond_field: FluxConditioningField | list[FluxConditioningField], - latent_height: int, - latent_width: int, - dtype: torch.dtype, - ) -> list[FluxTextConditioning]: - """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields.""" - # Normalize to a list of FluxConditioningFields. - cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field - - text_conditionings: list[FluxTextConditioning] = [] - for cond_field in cond_list: - # Load the text embeddings. - cond_data = context.conditioning.load(cond_field.conditioning_name) - assert len(cond_data.conditionings) == 1 - flux_conditioning = cond_data.conditionings[0] - assert isinstance(flux_conditioning, FLUXConditioningInfo) - flux_conditioning = flux_conditioning.to(dtype=dtype) - t5_embeddings = flux_conditioning.t5_embeds - clip_embeddings = flux_conditioning.clip_embeds - - # Load the mask, if provided. - mask: Optional[torch.Tensor] = None - if cond_field.mask is not None: - mask = context.tensors.load(cond_field.mask.tensor_name) - mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype) - - text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask)) - - return text_conditionings - - def _concat_regional_text_conditioning( - self, text_conditionings: list[FluxTextConditioning] - ) -> FluxRegionalTextConditioning: - """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks).""" - concat_t5_embeddings: list[torch.Tensor] = [] - concat_clip_embeddings: list[torch.Tensor] = [] - concat_image_masks: list[torch.Tensor] = [] - concat_t5_embedding_ranges: list[Range] = [] - concat_clip_embedding_ranges: list[Range] = [] - - cur_t5_embedding_len = 0 - cur_clip_embedding_len = 0 - for text_conditioning in text_conditionings: - concat_t5_embeddings.append(text_conditioning.t5_embeddings) - concat_clip_embeddings.append(text_conditioning.clip_embeddings) - - concat_t5_embedding_ranges.append( - Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1]) - ) - concat_clip_embedding_ranges.append( - Range( - start=cur_clip_embedding_len, - end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1], - ) - ) - - concat_image_masks.append(text_conditioning.mask) - - cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1] - cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1] - - t5_embeddings = torch.cat(concat_t5_embeddings, dim=1) - - # Initialize the txt_ids tensor. - pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape - t5_txt_ids = torch.zeros( - pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device() - ) - - return FluxRegionalTextConditioning( - t5_embeddings=t5_embeddings, - clip_embeddings=torch.cat(concat_clip_embeddings, dim=1), - t5_txt_ids=t5_txt_ids, - image_masks=torch.cat(concat_image_masks, dim=1), - t5_embedding_ranges=concat_t5_embedding_ranges, - clip_embedding_ranges=concat_clip_embedding_ranges, - ) - def _run_diffusion( self, context: InvocationContext, @@ -288,10 +180,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): latent_width=latent_w, dtype=inference_dtype, ) - - pos_regional_text_conditioning = self._concat_regional_text_conditioning(pos_text_conditionings) - neg_regional_text_conditioning = ( - self._concat_regional_text_conditioning(neg_text_conditionings) if neg_text_conditionings else None + pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(pos_text_conditionings) + neg_regional_prompting_extension = ( + RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings) + if neg_text_conditionings + else None ) transformer_info = context.models.load(self.transformer.transformer) @@ -436,8 +329,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): model=transformer, img=x, img_ids=img_ids, - pos_text_conditioning=pos_regional_text_conditioning, - neg_text_conditioning=neg_regional_text_conditioning, + pos_regional_prompting_extension=pos_regional_prompting_extension, + neg_regional_prompting_extension=neg_regional_prompting_extension, timesteps=timesteps, step_callback=self._build_step_callback(context), guidance=self.guidance, @@ -451,6 +344,39 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): x = unpack(x.float(), self.height, self.width) return x + def _load_text_conditioning( + self, + context: InvocationContext, + cond_field: FluxConditioningField | list[FluxConditioningField], + latent_height: int, + latent_width: int, + dtype: torch.dtype, + ) -> list[FluxTextConditioning]: + """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields.""" + # Normalize to a list of FluxConditioningFields. + cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field + + text_conditionings: list[FluxTextConditioning] = [] + for cond_field in cond_list: + # Load the text embeddings. + cond_data = context.conditioning.load(cond_field.conditioning_name) + assert len(cond_data.conditionings) == 1 + flux_conditioning = cond_data.conditionings[0] + assert isinstance(flux_conditioning, FLUXConditioningInfo) + flux_conditioning = flux_conditioning.to(dtype=dtype) + t5_embeddings = flux_conditioning.t5_embeds + clip_embeddings = flux_conditioning.clip_embeds + + # Load the mask, if provided. + mask: Optional[torch.Tensor] = None + if cond_field.mask is not None: + mask = context.tensors.load(cond_field.mask.tensor_name) + mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype) + + text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask)) + + return text_conditionings + @classmethod def prep_cfg_scale( cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py index e0c7779e93..ae339cbd0e 100644 --- a/invokeai/backend/flux/custom_block_processor.py +++ b/invokeai/backend/flux/custom_block_processor.py @@ -1,6 +1,7 @@ import einops import torch +from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.math import attention from invokeai.backend.flux.modules.layers import DoubleStreamBlock @@ -63,6 +64,7 @@ class CustomDoubleStreamBlockProcessor: vec: torch.Tensor, pe: torch.Tensor, ip_adapter_extensions: list[XLabsIPAdapterExtension], + regional_prompting_extension: RegionalPromptingExtension, ) -> tuple[torch.Tensor, torch.Tensor]: """A custom implementation of DoubleStreamBlock.forward() with additional features: - IP-Adapter support diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index c1cb3bbb71..ae541ba7d8 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -7,10 +7,10 @@ from tqdm import tqdm from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension +from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.model import Flux -from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState @@ -19,8 +19,8 @@ def denoise( # model input img: torch.Tensor, img_ids: torch.Tensor, - pos_text_conditioning: FluxRegionalTextConditioning, - neg_text_conditioning: FluxRegionalTextConditioning | None, + pos_regional_prompting_extension: RegionalPromptingExtension, + neg_regional_prompting_extension: RegionalPromptingExtension | None, # sampling parameters timesteps: list[float], step_callback: Callable[[PipelineIntermediateState], None], @@ -50,16 +50,16 @@ def denoise( # Run ControlNet models. controlnet_residuals: list[ControlNetFluxOutput] = [] for controlnet_extension in controlnet_extensions: - # FIX(ryand): Revive ControlNet functionality. + # TODO(ryand): Think about how to handle regional prompting with ControlNet. controlnet_residuals.append( controlnet_extension.run_controlnet( timestep_index=step_index, total_num_timesteps=total_steps, img=img, img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, + txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, timesteps=t_vec, guidance=guidance_vec, ) @@ -74,9 +74,9 @@ def denoise( pred = model( img=img, img_ids=img_ids, - txt=txt, - txt_ids=txt_ids, - y=vec, + txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings, timesteps=t_vec, guidance=guidance_vec, timestep_index=step_index, @@ -84,6 +84,7 @@ def denoise( controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals, controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals, ip_adapter_extensions=pos_ip_adapter_extensions, + regional_prompting_extension=pos_regional_prompting_extension, ) step_cfg_scale = cfg_scale[step_index] @@ -93,15 +94,15 @@ def denoise( # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance # on systems with sufficient VRAM. - if neg_txt is None or neg_txt_ids is None or neg_vec is None: + if neg_regional_prompting_extension is None: raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.") neg_pred = model( img=img, img_ids=img_ids, - txt=neg_txt, - txt_ids=neg_txt_ids, - y=neg_vec, + txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings, + txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids, + y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings, timesteps=t_vec, guidance=guidance_vec, timestep_index=step_index, diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py new file mode 100644 index 0000000000..21b2279b14 --- /dev/null +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -0,0 +1,96 @@ +from typing import Optional + +import torch +import torchvision + +from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning +from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range +from invokeai.backend.util.devices import TorchDevice +from invokeai.backend.util.mask import to_standard_float_mask + + +class RegionalPromptingExtension: + """A class for managing regional prompting with FLUX.""" + + def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning): + self.regional_text_conditioning = regional_text_conditioning + + @classmethod + def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]): + return cls(regional_text_conditioning=cls._concat_regional_text_conditioning(text_conditioning)) + + @classmethod + def _concat_regional_text_conditioning( + cls, + text_conditionings: list[FluxTextConditioning], + ) -> FluxRegionalTextConditioning: + """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks).""" + concat_t5_embeddings: list[torch.Tensor] = [] + concat_clip_embeddings: list[torch.Tensor] = [] + concat_image_masks: list[torch.Tensor] = [] + concat_t5_embedding_ranges: list[Range] = [] + concat_clip_embedding_ranges: list[Range] = [] + + cur_t5_embedding_len = 0 + cur_clip_embedding_len = 0 + for text_conditioning in text_conditionings: + concat_t5_embeddings.append(text_conditioning.t5_embeddings) + concat_clip_embeddings.append(text_conditioning.clip_embeddings) + + concat_t5_embedding_ranges.append( + Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1]) + ) + concat_clip_embedding_ranges.append( + Range( + start=cur_clip_embedding_len, + end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1], + ) + ) + + concat_image_masks.append(text_conditioning.mask) + + cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1] + cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1] + + t5_embeddings = torch.cat(concat_t5_embeddings, dim=1) + + # Initialize the txt_ids tensor. + pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape + t5_txt_ids = torch.zeros( + pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device() + ) + + return FluxRegionalTextConditioning( + t5_embeddings=t5_embeddings, + clip_embeddings=torch.cat(concat_clip_embeddings, dim=1), + t5_txt_ids=t5_txt_ids, + image_masks=torch.cat(concat_image_masks, dim=1), + t5_embedding_ranges=concat_t5_embedding_ranges, + clip_embedding_ranges=concat_clip_embedding_ranges, + ) + + @staticmethod + def preprocess_regional_prompt_mask( + mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype + ) -> torch.Tensor: + """Preprocess a regional prompt mask to match the target height and width. + If mask is None, returns a mask of all ones with the target height and width. + If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. + + Returns: + torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). + """ + + if mask is None: + return torch.ones((1, 1, target_height, target_width), dtype=dtype) + + mask = to_standard_float_mask(mask, out_dtype=dtype) + + tf = torchvision.transforms.Resize( + (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + ) + + # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). + mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) + resized_mask = tf(mask) + return resized_mask diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py index 0dadacd8fe..a0c44f0039 100644 --- a/invokeai/backend/flux/model.py +++ b/invokeai/backend/flux/model.py @@ -6,6 +6,7 @@ import torch from torch import Tensor, nn from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor +from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.modules.layers import ( DoubleStreamBlock, @@ -95,6 +96,7 @@ class Flux(nn.Module): controlnet_double_block_residuals: list[Tensor] | None, controlnet_single_block_residuals: list[Tensor] | None, ip_adapter_extensions: list[XLabsIPAdapterExtension], + regional_prompting_extension: RegionalPromptingExtension, ) -> Tensor: if img.ndim != 3 or txt.ndim != 3: raise ValueError("Input img and txt tensors must have 3 dimensions.") @@ -128,6 +130,7 @@ class Flux(nn.Module): vec=vec, pe=pe, ip_adapter_extensions=ip_adapter_extensions, + regional_prompting_extension=regional_prompting_extension, ) if controlnet_double_block_residuals is not None: From bad11495046690cdbfd44155e46d7e5df655e10e Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 20 Nov 2024 22:29:36 +0000 Subject: [PATCH 03/28] WIP - add rough logic for preparing the FLUX regional prompting attention mask. --- .../regional_prompting_extension.py | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 21b2279b14..43a3cc4395 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -10,7 +10,10 @@ from invokeai.backend.util.mask import to_standard_float_mask class RegionalPromptingExtension: - """A class for managing regional prompting with FLUX.""" + """A class for managing regional prompting with FLUX. + + Implementation inspired by: https://arxiv.org/pdf/2411.02395 + """ def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning): self.regional_text_conditioning = regional_text_conditioning @@ -19,6 +22,51 @@ class RegionalPromptingExtension: def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]): return cls(regional_text_conditioning=cls._concat_regional_text_conditioning(text_conditioning)) + def _prepare_attn_mask(self) -> torch.Tensor: + device = self.regional_text_conditioning.image_masks[0].device + # img_seq_len = latent_height * latent_width + img_seq_len = ( + self.regional_text_conditioning.image_masks.shape[-1] + * self.regional_text_conditioning.image_masks.shape[-2] + ) + txt_seq_len = self.regional_text_conditioning.t5_embeddings.shape[1] + + # In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied. + # Concatenation happens in the following order: [txt_seq, img_seq]. + # There are 4 portions of the attention mask to consider as we prepare it: + # 1. txt attends to itself + # 2. txt attends to corresponding regional img + # 3. regional img attends to corresponding txt + # 4. regional img attends to itself + + # Initialize empty attention mask. + regional_attention_mask = torch.zeros( + (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.bool + ) + + for i in range(len(self.regional_text_conditioning.t5_embeddings)): + image_mask = self.regional_text_conditioning.image_masks[i].flatten() + t5_embedding_range = self.regional_text_conditioning.t5_embedding_ranges[i] + + # 1. txt attends to itself + regional_attention_mask[ + t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end + ] = True + + # 2. txt attends to corresponding regional img + # TODO(ryand): Make sure that broadcasting works as expected. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask + + # 3. regional img attends to corresponding txt + # TODO(ryand): Make sure that broadcasting works as expected. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask + + # 4. regional img attends to itself + # TODO(ryand): Make sure that broadcasting works as expected. + regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T + + return regional_attention_mask + @classmethod def _concat_regional_text_conditioning( cls, From 20356c07466a6f58e1e55521453af66fe6f76655 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 21 Nov 2024 22:46:25 +0000 Subject: [PATCH 04/28] Fixup the logic for preparing FLUX regional prompt attention masks. --- invokeai/app/invocations/flux_denoise.py | 24 ++++++----- invokeai/app/invocations/flux_text_encoder.py | 11 ++++- .../backend/flux/custom_block_processor.py | 9 ++++- .../regional_prompting_extension.py | 40 +++++++++++-------- invokeai/backend/flux/math.py | 4 +- invokeai/backend/flux/text_conditioning.py | 2 +- 6 files changed, 55 insertions(+), 35 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 67bcfc785f..1d8c149759 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -162,13 +162,15 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): seed=self.seed, ) b, _c, latent_h, latent_w = noise.shape + packed_h = latent_h // 2 + packed_w = latent_w // 2 # Load the conditioning data. pos_text_conditionings = self._load_text_conditioning( context=context, cond_field=self.positive_text_conditioning, - latent_height=latent_h, - latent_width=latent_w, + packed_height=packed_h, + packed_width=packed_w, dtype=inference_dtype, ) neg_text_conditionings: list[FluxTextConditioning] | None = None @@ -176,8 +178,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): neg_text_conditionings = self._load_text_conditioning( context=context, cond_field=self.negative_text_conditioning, - latent_height=latent_h, - latent_width=latent_w, + packed_height=packed_h, + packed_width=packed_w, dtype=inference_dtype, ) pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(pos_text_conditionings) @@ -191,10 +193,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): is_schnell = "schnell" in transformer_info.config.config_path # Calculate the timestep schedule. - image_seq_len = noise.shape[-1] * noise.shape[-2] // 4 timesteps = get_schedule( num_steps=self.num_steps, - image_seq_len=image_seq_len, + image_seq_len=packed_h * packed_w, shift=not is_schnell, ) @@ -239,8 +240,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): noise = pack(noise) x = pack(x) - # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len correctly. - assert image_seq_len == x.shape[1] + # Now that we have 'packed' the latent tensors, verify that we calculated the image_seq_len, packed_h, and + # packed_w correctly. + assert packed_h * packed_w == x.shape[1] # Prepare inpaint extension. inpaint_extension: InpaintExtension | None = None @@ -348,8 +350,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): self, context: InvocationContext, cond_field: FluxConditioningField | list[FluxConditioningField], - latent_height: int, - latent_width: int, + packed_height: int, + packed_width: int, dtype: torch.dtype, ) -> list[FluxTextConditioning]: """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields.""" @@ -371,7 +373,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): mask: Optional[torch.Tensor] = None if cond_field.mask is not None: mask = context.tensors.load(cond_field.mask.tensor_name) - mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype) + mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(mask, packed_height, packed_width, dtype) text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask)) diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index cc9c68eca4..1eb0fea62e 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -5,7 +5,14 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation -from invokeai.app.invocations.fields import FieldDescriptions, FluxConditioningField, Input, InputField, TensorField +from invokeai.app.invocations.fields import ( + FieldDescriptions, + FluxConditioningField, + Input, + InputField, + TensorField, + UIComponent, +) from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext @@ -41,7 +48,7 @@ class FluxTextEncoderInvocation(BaseInvocation): t5_max_seq_len: Literal[256, 512] = InputField( description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models." ) - prompt: str = InputField(description="Text prompt to encode.") + prompt: str = InputField(description="Text prompt to encode.", ui_component=UIComponent.Textarea) mask: Optional[TensorField] = InputField( default=None, description="A mask defining the region that this conditioning prompt applies to." ) diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py index ae339cbd0e..02c3447b1f 100644 --- a/invokeai/backend/flux/custom_block_processor.py +++ b/invokeai/backend/flux/custom_block_processor.py @@ -14,7 +14,12 @@ class CustomDoubleStreamBlockProcessor: @staticmethod def _double_stream_block_forward( - block: DoubleStreamBlock, img: torch.Tensor, txt: torch.Tensor, vec: torch.Tensor, pe: torch.Tensor + block: DoubleStreamBlock, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + pe: torch.Tensor, + attn_mask: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """This function is a direct copy of DoubleStreamBlock.forward(), but it returns some of the intermediate values. @@ -41,7 +46,7 @@ class CustomDoubleStreamBlockProcessor: k = torch.cat((txt_k, img_k), dim=2) v = torch.cat((txt_v, img_v), dim=2) - attn = attention(q, k, v, pe=pe) + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] # calculate the img bloks diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 43a3cc4395..3b782ff884 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -17,6 +17,7 @@ class RegionalPromptingExtension: def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning): self.regional_text_conditioning = regional_text_conditioning + self.attn_mask = self._prepare_attn_mask() @classmethod def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]): @@ -24,11 +25,8 @@ class RegionalPromptingExtension: def _prepare_attn_mask(self) -> torch.Tensor: device = self.regional_text_conditioning.image_masks[0].device - # img_seq_len = latent_height * latent_width - img_seq_len = ( - self.regional_text_conditioning.image_masks.shape[-1] - * self.regional_text_conditioning.image_masks.shape[-2] - ) + # img_seq_len = packed_height * packed_width + img_seq_len = self.regional_text_conditioning.image_masks.shape[2] txt_seq_len = self.regional_text_conditioning.t5_embeddings.shape[1] # In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied. @@ -44,8 +42,8 @@ class RegionalPromptingExtension: (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.bool ) - for i in range(len(self.regional_text_conditioning.t5_embeddings)): - image_mask = self.regional_text_conditioning.image_masks[i].flatten() + for i in range(len(self.regional_text_conditioning.t5_embedding_ranges)): + image_mask = self.regional_text_conditioning.image_masks[0, i] t5_embedding_range = self.regional_text_conditioning.t5_embedding_ranges[i] # 1. txt attends to itself @@ -54,15 +52,19 @@ class RegionalPromptingExtension: ] = True # 2. txt attends to corresponding regional img - # TODO(ryand): Make sure that broadcasting works as expected. - regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask + # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view( + 1, img_seq_len + ) # 3. regional img attends to corresponding txt - # TODO(ryand): Make sure that broadcasting works as expected. - regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask + # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view( + img_seq_len, 1 + ) # 4. regional img attends to itself - # TODO(ryand): Make sure that broadcasting works as expected. + image_mask = image_mask.view(img_seq_len, 1) regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T return regional_attention_mask @@ -119,26 +121,30 @@ class RegionalPromptingExtension: @staticmethod def preprocess_regional_prompt_mask( - mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype + mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype ) -> torch.Tensor: """Preprocess a regional prompt mask to match the target height and width. If mask is None, returns a mask of all ones with the target height and width. If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation. + packed_height and packed_width are the target height and width of the mask in the 'packed' latent space. + Returns: - torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width). + torch.Tensor: The processed mask. shape: (1, 1, packed_height * packed_width). """ if mask is None: - return torch.ones((1, 1, target_height, target_width), dtype=dtype) + return torch.ones((1, 1, packed_height * packed_width), dtype=dtype) mask = to_standard_float_mask(mask, out_dtype=dtype) tf = torchvision.transforms.Resize( - (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST + (packed_height, packed_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST ) # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w). mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w) resized_mask = tf(mask) - return resized_mask + + # Flatten the height and width dimensions into a single image_seq_len dimension. + return resized_mask.flatten(start_dim=2) diff --git a/invokeai/backend/flux/math.py b/invokeai/backend/flux/math.py index 0fac7a1d16..260243a35d 100644 --- a/invokeai/backend/flux/math.py +++ b/invokeai/backend/flux/math.py @@ -5,10 +5,10 @@ from einops import rearrange from torch import Tensor -def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor: +def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, attn_mask: Tensor | None = None) -> Tensor: q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) x = rearrange(x, "B H L D -> B L (H D)") return x diff --git a/invokeai/backend/flux/text_conditioning.py b/invokeai/backend/flux/text_conditioning.py index 82d37688d1..5276d1e089 100644 --- a/invokeai/backend/flux/text_conditioning.py +++ b/invokeai/backend/flux/text_conditioning.py @@ -21,7 +21,7 @@ class FluxRegionalTextConditioning: t5_txt_ids: torch.Tensor # A binary mask indicating the regions of the image that the prompt should be applied to. - # Shape: (1, num_prompts, height, width) + # Shape: (1, num_prompts, image_seq_len) # Dtype: torch.bool image_masks: torch.Tensor From 2c23b8414c4582404cfffe8bdf0b95e014495106 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 22 Nov 2024 23:01:43 +0000 Subject: [PATCH 05/28] Use a single global CLIP embedding for FLUX regional guidance. --- .../extensions/regional_prompting_extension.py | 15 ++------------- invokeai/backend/flux/text_conditioning.py | 10 ++++++---- 2 files changed, 8 insertions(+), 17 deletions(-) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 3b782ff884..4f0df1a9c4 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -76,31 +76,20 @@ class RegionalPromptingExtension: ) -> FluxRegionalTextConditioning: """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks).""" concat_t5_embeddings: list[torch.Tensor] = [] - concat_clip_embeddings: list[torch.Tensor] = [] concat_image_masks: list[torch.Tensor] = [] concat_t5_embedding_ranges: list[Range] = [] - concat_clip_embedding_ranges: list[Range] = [] cur_t5_embedding_len = 0 - cur_clip_embedding_len = 0 for text_conditioning in text_conditionings: concat_t5_embeddings.append(text_conditioning.t5_embeddings) - concat_clip_embeddings.append(text_conditioning.clip_embeddings) concat_t5_embedding_ranges.append( Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1]) ) - concat_clip_embedding_ranges.append( - Range( - start=cur_clip_embedding_len, - end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1], - ) - ) concat_image_masks.append(text_conditioning.mask) cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1] - cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1] t5_embeddings = torch.cat(concat_t5_embeddings, dim=1) @@ -112,11 +101,11 @@ class RegionalPromptingExtension: return FluxRegionalTextConditioning( t5_embeddings=t5_embeddings, - clip_embeddings=torch.cat(concat_clip_embeddings, dim=1), + # HACK(ryand): Be smarter about how we select which CLIP embedding to use. + clip_embeddings=text_conditionings[0].clip_embeddings, t5_txt_ids=t5_txt_ids, image_masks=torch.cat(concat_image_masks, dim=1), t5_embedding_ranges=concat_t5_embedding_ranges, - clip_embedding_ranges=concat_clip_embedding_ranges, ) @staticmethod diff --git a/invokeai/backend/flux/text_conditioning.py b/invokeai/backend/flux/text_conditioning.py index 5276d1e089..fdfa884444 100644 --- a/invokeai/backend/flux/text_conditioning.py +++ b/invokeai/backend/flux/text_conditioning.py @@ -15,11 +15,15 @@ class FluxTextConditioning: @dataclass class FluxRegionalTextConditioning: # Concatenated text embeddings. + # Shape: (1, concatenated_txt_seq_len, 4096) t5_embeddings: torch.Tensor - clip_embeddings: torch.Tensor - + # Shape: (1, concatenated_txt_seq_len, 3) t5_txt_ids: torch.Tensor + # Global CLIP embeddings. + # Shape: (1, 768) + clip_embeddings: torch.Tensor + # A binary mask indicating the regions of the image that the prompt should be applied to. # Shape: (1, num_prompts, image_seq_len) # Dtype: torch.bool @@ -27,6 +31,4 @@ class FluxRegionalTextConditioning: # List of ranges that represent the embedding ranges for each mask. # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i]. - # clip_embedding_ranges[i] contains the range of the clip embeddings that correspond to image_masks[i]. t5_embedding_ranges: list[Range] - clip_embedding_ranges: list[Range] From 3741a6f5e05f11f7a71b5edcc116fe3abd773583 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 25 Nov 2024 16:02:03 +0000 Subject: [PATCH 06/28] Fix device handling for regional masks and apply the attention mask in the FLUX double stream block. --- invokeai/app/invocations/flux_denoise.py | 10 ++++++++-- invokeai/backend/flux/custom_block_processor.py | 4 +++- .../flux/extensions/regional_prompting_extension.py | 4 ++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 1d8c149759..35f15f05d8 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -172,6 +172,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): packed_height=packed_h, packed_width=packed_w, dtype=inference_dtype, + device=TorchDevice.choose_torch_device(), ) neg_text_conditionings: list[FluxTextConditioning] | None = None if self.negative_text_conditioning is not None: @@ -181,6 +182,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): packed_height=packed_h, packed_width=packed_w, dtype=inference_dtype, + device=TorchDevice.choose_torch_device(), ) pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(pos_text_conditionings) neg_regional_prompting_extension = ( @@ -353,6 +355,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): packed_height: int, packed_width: int, dtype: torch.dtype, + device: torch.device, ) -> list[FluxTextConditioning]: """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields.""" # Normalize to a list of FluxConditioningFields. @@ -365,7 +368,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): assert len(cond_data.conditionings) == 1 flux_conditioning = cond_data.conditionings[0] assert isinstance(flux_conditioning, FLUXConditioningInfo) - flux_conditioning = flux_conditioning.to(dtype=dtype) + flux_conditioning = flux_conditioning.to(dtype=dtype, device=device) t5_embeddings = flux_conditioning.t5_embeds clip_embeddings = flux_conditioning.clip_embeds @@ -373,7 +376,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): mask: Optional[torch.Tensor] = None if cond_field.mask is not None: mask = context.tensors.load(cond_field.mask.tensor_name) - mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(mask, packed_height, packed_width, dtype) + mask = mask.to(device=device) + mask = RegionalPromptingExtension.preprocess_regional_prompt_mask( + mask, packed_height, packed_width, dtype, device + ) text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask)) diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py index 02c3447b1f..edc57d624e 100644 --- a/invokeai/backend/flux/custom_block_processor.py +++ b/invokeai/backend/flux/custom_block_processor.py @@ -74,7 +74,9 @@ class CustomDoubleStreamBlockProcessor: """A custom implementation of DoubleStreamBlock.forward() with additional features: - IP-Adapter support """ - img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward(block, img, txt, vec, pe) + img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward( + block, img, txt, vec, pe, attn_mask=regional_prompting_extension.attn_mask + ) # Apply IP-Adapter conditioning. for ip_adapter_extension in ip_adapter_extensions: diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 4f0df1a9c4..b61024e67b 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -110,7 +110,7 @@ class RegionalPromptingExtension: @staticmethod def preprocess_regional_prompt_mask( - mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype + mask: Optional[torch.Tensor], packed_height: int, packed_width: int, dtype: torch.dtype, device: torch.device ) -> torch.Tensor: """Preprocess a regional prompt mask to match the target height and width. If mask is None, returns a mask of all ones with the target height and width. @@ -123,7 +123,7 @@ class RegionalPromptingExtension: """ if mask is None: - return torch.ones((1, 1, packed_height * packed_width), dtype=dtype) + return torch.ones((1, 1, packed_height * packed_width), dtype=dtype, device=device) mask = to_standard_float_mask(mask, out_dtype=dtype) From 94c088300f926f5f8db1ac383b8be07fb4fea0bb Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 25 Nov 2024 20:15:04 +0000 Subject: [PATCH 07/28] Be smarter about selecting the global CLIP embedding for FLUX regional prompting. --- .../extensions/regional_prompting_extension.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index b61024e67b..b7f4616782 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -64,8 +64,9 @@ class RegionalPromptingExtension: ) # 4. regional img attends to itself - image_mask = image_mask.view(img_seq_len, 1) - regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T + # image_mask = image_mask.view(img_seq_len, 1) + # regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T + regional_attention_mask[txt_seq_len:, txt_seq_len:] = True return regional_attention_mask @@ -79,6 +80,15 @@ class RegionalPromptingExtension: concat_image_masks: list[torch.Tensor] = [] concat_t5_embedding_ranges: list[Range] = [] + # Choose global CLIP embedding. + # Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use + # the first prompt's CLIP embedding. + global_clip_embedding: torch.Tensor = text_conditionings[0].clip_embeddings + for text_conditioning in text_conditionings: + if text_conditioning.mask is None: + global_clip_embedding = text_conditioning.clip_embeddings + break + cur_t5_embedding_len = 0 for text_conditioning in text_conditionings: concat_t5_embeddings.append(text_conditioning.t5_embeddings) @@ -101,8 +111,7 @@ class RegionalPromptingExtension: return FluxRegionalTextConditioning( t5_embeddings=t5_embeddings, - # HACK(ryand): Be smarter about how we select which CLIP embedding to use. - clip_embeddings=text_conditionings[0].clip_embeddings, + clip_embeddings=global_clip_embedding, t5_txt_ids=t5_txt_ids, image_masks=torch.cat(concat_image_masks, dim=1), t5_embedding_ranges=concat_t5_embedding_ranges, From 53abdde24202b2cff5b6fddd8f688a91a6ec82e1 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 25 Nov 2024 22:04:23 +0000 Subject: [PATCH 08/28] Update Flux RegionalPromptingExtension to prepare both a mask with restricted image self-attention and a mask with unrestricted image self attention. --- invokeai/app/invocations/flux_denoise.py | 12 ++- .../backend/flux/custom_block_processor.py | 8 +- .../regional_prompting_extension.py | 93 +++++++++++++------ invokeai/backend/flux/text_conditioning.py | 12 ++- 4 files changed, 87 insertions(+), 38 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index 35f15f05d8..f2a2dd586a 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -184,9 +184,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): dtype=inference_dtype, device=TorchDevice.choose_torch_device(), ) - pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(pos_text_conditionings) + pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning( + pos_text_conditionings, img_seq_len=packed_h * packed_w + ) neg_regional_prompting_extension = ( - RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings) + RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings, img_seq_len=packed_h * packed_w) if neg_text_conditionings else None ) @@ -377,9 +379,9 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): if cond_field.mask is not None: mask = context.tensors.load(cond_field.mask.tensor_name) mask = mask.to(device=device) - mask = RegionalPromptingExtension.preprocess_regional_prompt_mask( - mask, packed_height, packed_width, dtype, device - ) + mask = RegionalPromptingExtension.preprocess_regional_prompt_mask( + mask, packed_height, packed_width, dtype, device + ) text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask)) diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py index edc57d624e..8af4912856 100644 --- a/invokeai/backend/flux/custom_block_processor.py +++ b/invokeai/backend/flux/custom_block_processor.py @@ -74,8 +74,14 @@ class CustomDoubleStreamBlockProcessor: """A custom implementation of DoubleStreamBlock.forward() with additional features: - IP-Adapter support """ + img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward( - block, img, txt, vec, pe, attn_mask=regional_prompting_extension.attn_mask + block, + img, + txt, + vec, + pe, + attn_mask=regional_prompting_extension.attn_mask_with_unrestricted_img_self_attn, ) # Apply IP-Adapter conditioning. diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index b7f4616782..1986c7f4aa 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -15,19 +15,54 @@ class RegionalPromptingExtension: Implementation inspired by: https://arxiv.org/pdf/2411.02395 """ - def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning): + def __init__( + self, + regional_text_conditioning: FluxRegionalTextConditioning, + attn_mask_with_restricted_img_self_attn: torch.Tensor | None = None, + attn_mask_with_unrestricted_img_self_attn: torch.Tensor | None = None, + ): self.regional_text_conditioning = regional_text_conditioning - self.attn_mask = self._prepare_attn_mask() + self.attn_mask_with_restricted_img_self_attn = attn_mask_with_restricted_img_self_attn + self.attn_mask_with_unrestricted_img_self_attn = attn_mask_with_unrestricted_img_self_attn @classmethod - def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]): - return cls(regional_text_conditioning=cls._concat_regional_text_conditioning(text_conditioning)) + def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int): + """Create a RegionalPromptingExtension from a list of text conditionings. - def _prepare_attn_mask(self) -> torch.Tensor: - device = self.regional_text_conditioning.image_masks[0].device - # img_seq_len = packed_height * packed_width - img_seq_len = self.regional_text_conditioning.image_masks.shape[2] - txt_seq_len = self.regional_text_conditioning.t5_embeddings.shape[1] + Args: + text_conditioning (list[FluxTextConditioning]): The text conditionings to use for regional prompting. + img_seq_len (int): The image sequence length (i.e. packed_height * packed_width). + """ + regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning) + attn_mask_with_restricted_img_self_attn = cls._prepare_attn_mask( + regional_text_conditioning, img_seq_len, restrict_img_self_attn=True + ) + attn_mask_with_unrestricted_img_self_attn = cls._prepare_attn_mask( + regional_text_conditioning, img_seq_len, restrict_img_self_attn=False + ) + return cls( + regional_text_conditioning=regional_text_conditioning, + attn_mask_with_restricted_img_self_attn=attn_mask_with_restricted_img_self_attn, + attn_mask_with_unrestricted_img_self_attn=attn_mask_with_unrestricted_img_self_attn, + ) + + @classmethod + def _prepare_attn_mask( + cls, + regional_text_conditioning: FluxRegionalTextConditioning, + img_seq_len: int, + restrict_img_self_attn: bool, + ) -> torch.Tensor: + device = TorchDevice.choose_torch_device() + + # Infer txt_seq_len from the t5_embeddings tensor. + txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1] + + # Decide whether to compute the img self-attention region mask. + # When compute_img_self_attn_region_mask is True, img self attention is only allowed within regions. + # When compute_img_self_attn_region_mask is False, img self attention is not constrained. + has_region_masks = any(mask is not None for mask in regional_text_conditioning.image_masks) + compute_img_self_attn_region_mask = restrict_img_self_attn and has_region_masks # In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied. # Concatenation happens in the following order: [txt_seq, img_seq]. @@ -39,34 +74,38 @@ class RegionalPromptingExtension: # Initialize empty attention mask. regional_attention_mask = torch.zeros( - (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.bool + (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16 ) - for i in range(len(self.regional_text_conditioning.t5_embedding_ranges)): - image_mask = self.regional_text_conditioning.image_masks[0, i] - t5_embedding_range = self.regional_text_conditioning.t5_embedding_ranges[i] - + for image_mask, t5_embedding_range in zip( + regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True + ): # 1. txt attends to itself regional_attention_mask[ t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end - ] = True + ] = 1.0 # 2. txt attends to corresponding regional img # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. - regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view( - 1, img_seq_len - ) + fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0 + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value # 3. regional img attends to corresponding txt # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. - regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view( - img_seq_len, 1 - ) + fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0 + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value # 4. regional img attends to itself - # image_mask = image_mask.view(img_seq_len, 1) - # regional_attention_mask[txt_seq_len:, txt_seq_len:] = image_mask @ image_mask.T - regional_attention_mask[txt_seq_len:, txt_seq_len:] = True + if compute_img_self_attn_region_mask and image_mask is not None: + image_mask = image_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + + if not compute_img_self_attn_region_mask: + # Allow unrestricted img self attention. + regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0 + + # Convert attention mask to boolean. + regional_attention_mask = regional_attention_mask > 0.5 return regional_attention_mask @@ -77,8 +116,8 @@ class RegionalPromptingExtension: ) -> FluxRegionalTextConditioning: """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks).""" concat_t5_embeddings: list[torch.Tensor] = [] - concat_image_masks: list[torch.Tensor] = [] concat_t5_embedding_ranges: list[Range] = [] + image_masks: list[torch.Tensor | None] = [] # Choose global CLIP embedding. # Use the first global prompt's CLIP embedding as the global CLIP embedding. If there is no global prompt, use @@ -97,7 +136,7 @@ class RegionalPromptingExtension: Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1]) ) - concat_image_masks.append(text_conditioning.mask) + image_masks.append(text_conditioning.mask) cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1] @@ -113,7 +152,7 @@ class RegionalPromptingExtension: t5_embeddings=t5_embeddings, clip_embeddings=global_clip_embedding, t5_txt_ids=t5_txt_ids, - image_masks=torch.cat(concat_image_masks, dim=1), + image_masks=image_masks, t5_embedding_ranges=concat_t5_embedding_ranges, ) diff --git a/invokeai/backend/flux/text_conditioning.py b/invokeai/backend/flux/text_conditioning.py index fdfa884444..5bc8b4d041 100644 --- a/invokeai/backend/flux/text_conditioning.py +++ b/invokeai/backend/flux/text_conditioning.py @@ -9,7 +9,8 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range class FluxTextConditioning: t5_embeddings: torch.Tensor clip_embeddings: torch.Tensor - mask: torch.Tensor + # If mask is None, the prompt is a global prompt. + mask: torch.Tensor | None @dataclass @@ -24,10 +25,11 @@ class FluxRegionalTextConditioning: # Shape: (1, 768) clip_embeddings: torch.Tensor - # A binary mask indicating the regions of the image that the prompt should be applied to. - # Shape: (1, num_prompts, image_seq_len) - # Dtype: torch.bool - image_masks: torch.Tensor + # A binary mask indicating the regions of the image that the prompt should be applied to. If None, the prompt is a + # global prompt. + # image_masks[i] is the mask for the ith prompt. + # image_masks[i] has shape (1, image_seq_len) and dtype torch.bool. + image_masks: list[torch.Tensor | None] # List of ranges that represent the embedding ranges for each mask. # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i]. From e01f66b02633c741ad4b050dd0dac73324788951 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 25 Nov 2024 22:40:08 +0000 Subject: [PATCH 09/28] Apply regional attention masks in the single stream blocks in addition to the double stream blocks. --- .../backend/flux/custom_block_processor.py | 56 ++++++++++++++++--- .../regional_prompting_extension.py | 6 ++ invokeai/backend/flux/model.py | 18 +++++- 3 files changed, 69 insertions(+), 11 deletions(-) diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py index 8af4912856..dd180eb56c 100644 --- a/invokeai/backend/flux/custom_block_processor.py +++ b/invokeai/backend/flux/custom_block_processor.py @@ -4,7 +4,7 @@ import torch from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.math import attention -from invokeai.backend.flux.modules.layers import DoubleStreamBlock +from invokeai.backend.flux.modules.layers import DoubleStreamBlock, SingleStreamBlock class CustomDoubleStreamBlockProcessor: @@ -74,14 +74,9 @@ class CustomDoubleStreamBlockProcessor: """A custom implementation of DoubleStreamBlock.forward() with additional features: - IP-Adapter support """ - + attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index) img, txt, img_q = CustomDoubleStreamBlockProcessor._double_stream_block_forward( - block, - img, - txt, - vec, - pe, - attn_mask=regional_prompting_extension.attn_mask_with_unrestricted_img_self_attn, + block, img, txt, vec, pe, attn_mask=attn_mask ) # Apply IP-Adapter conditioning. @@ -96,3 +91,48 @@ class CustomDoubleStreamBlockProcessor: ) return img, txt + + +class CustomSingleStreamBlockProcessor: + """A class containing a custom implementation of SingleStreamBlock.forward() with additional features (masking, + etc.) + """ + + @staticmethod + def _single_stream_block_forward( + block: SingleStreamBlock, + x: torch.Tensor, + vec: torch.Tensor, + pe: torch.Tensor, + attn_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """This function is a direct copy of SingleStreamBlock.forward().""" + mod, _ = block.modulation(vec) + x_mod = (1 + mod.scale) * block.pre_norm(x) + mod.shift + qkv, mlp = torch.split(block.linear1(x_mod), [3 * block.hidden_size, block.mlp_hidden_dim], dim=-1) + + q, k, v = einops.rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=block.num_heads) + q, k = block.norm(q, k, v) + + # compute attention + attn = attention(q, k, v, pe=pe, attn_mask=attn_mask) + # compute activation in mlp stream, cat again and run second linear layer + output = block.linear2(torch.cat((attn, block.mlp_act(mlp)), 2)) + return x + mod.gate * output + + @staticmethod + def custom_single_block_forward( + timestep_index: int, + total_num_timesteps: int, + block_index: int, + block: SingleStreamBlock, + img: torch.Tensor, + vec: torch.Tensor, + pe: torch.Tensor, + regional_prompting_extension: RegionalPromptingExtension, + ) -> torch.Tensor: + """A custom implementation of SingleStreamBlock.forward() with additional features: + - Masking + """ + attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index) + return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 1986c7f4aa..259db7c29a 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -25,6 +25,12 @@ class RegionalPromptingExtension: self.attn_mask_with_restricted_img_self_attn = attn_mask_with_restricted_img_self_attn self.attn_mask_with_unrestricted_img_self_attn = attn_mask_with_unrestricted_img_self_attn + def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None: + return self.attn_mask_with_unrestricted_img_self_attn + + def get_single_stream_attn_mask(self) -> torch.Tensor | None: + return self.attn_mask_with_unrestricted_img_self_attn + @classmethod def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int): """Create a RegionalPromptingExtension from a list of text conditionings. diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py index a0c44f0039..0add6fd4d7 100644 --- a/invokeai/backend/flux/model.py +++ b/invokeai/backend/flux/model.py @@ -5,7 +5,10 @@ from dataclasses import dataclass import torch from torch import Tensor, nn -from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor +from invokeai.backend.flux.custom_block_processor import ( + CustomDoubleStreamBlockProcessor, + CustomSingleStreamBlockProcessor, +) from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension from invokeai.backend.flux.modules.layers import ( @@ -119,7 +122,6 @@ class Flux(nn.Module): assert len(controlnet_double_block_residuals) == len(self.double_blocks) for block_index, block in enumerate(self.double_blocks): assert isinstance(block, DoubleStreamBlock) - img, txt = CustomDoubleStreamBlockProcessor.custom_double_block_forward( timestep_index=timestep_index, total_num_timesteps=total_num_timesteps, @@ -143,7 +145,17 @@ class Flux(nn.Module): assert len(controlnet_single_block_residuals) == len(self.single_blocks) for block_index, block in enumerate(self.single_blocks): - img = block(img, vec=vec, pe=pe) + assert isinstance(block, SingleStreamBlock) + img = CustomSingleStreamBlockProcessor.custom_single_block_forward( + timestep_index=timestep_index, + total_num_timesteps=total_num_timesteps, + block_index=block_index, + block=block, + img=img, + vec=vec, + pe=pe, + regional_prompting_extension=regional_prompting_extension, + ) if controlnet_single_block_residuals is not None: img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index] From faee79dc95cfb6e1b7727d44bb9f1e2d18bc6956 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 26 Nov 2024 16:55:52 +0000 Subject: [PATCH 10/28] Distinguish between restricted and unrestricted attn masks in FLUX regional prompting. --- .../backend/flux/custom_block_processor.py | 2 +- .../regional_prompting_extension.py | 130 ++++++++++++++---- 2 files changed, 102 insertions(+), 30 deletions(-) diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py index dd180eb56c..0f56adacde 100644 --- a/invokeai/backend/flux/custom_block_processor.py +++ b/invokeai/backend/flux/custom_block_processor.py @@ -134,5 +134,5 @@ class CustomSingleStreamBlockProcessor: """A custom implementation of SingleStreamBlock.forward() with additional features: - Masking """ - attn_mask = regional_prompting_extension.get_double_stream_attn_mask(block_index) + attn_mask = regional_prompting_extension.get_single_stream_attn_mask(block_index) return CustomSingleStreamBlockProcessor._single_stream_block_forward(block, img, vec, pe, attn_mask=attn_mask) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 259db7c29a..7d51e12508 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -18,18 +18,20 @@ class RegionalPromptingExtension: def __init__( self, regional_text_conditioning: FluxRegionalTextConditioning, - attn_mask_with_restricted_img_self_attn: torch.Tensor | None = None, - attn_mask_with_unrestricted_img_self_attn: torch.Tensor | None = None, + restricted_attn_mask: torch.Tensor | None = None, + # unrestricted_attn_mask: torch.Tensor | None = None, ): self.regional_text_conditioning = regional_text_conditioning - self.attn_mask_with_restricted_img_self_attn = attn_mask_with_restricted_img_self_attn - self.attn_mask_with_unrestricted_img_self_attn = attn_mask_with_unrestricted_img_self_attn + self.restricted_attn_mask = restricted_attn_mask + # self.unrestricted_attn_mask = unrestricted_attn_mask def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None: - return self.attn_mask_with_unrestricted_img_self_attn + order = [self.restricted_attn_mask, None] + return order[block_index % len(order)] - def get_single_stream_attn_mask(self) -> torch.Tensor | None: - return self.attn_mask_with_unrestricted_img_self_attn + def get_single_stream_attn_mask(self, block_index: int) -> torch.Tensor | None: + order = [self.restricted_attn_mask, None] + return order[block_index % len(order)] @classmethod def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning], img_seq_len: int): @@ -40,37 +42,34 @@ class RegionalPromptingExtension: img_seq_len (int): The image sequence length (i.e. packed_height * packed_width). """ regional_text_conditioning = cls._concat_regional_text_conditioning(text_conditioning) - attn_mask_with_restricted_img_self_attn = cls._prepare_attn_mask( - regional_text_conditioning, img_seq_len, restrict_img_self_attn=True - ) - attn_mask_with_unrestricted_img_self_attn = cls._prepare_attn_mask( - regional_text_conditioning, img_seq_len, restrict_img_self_attn=False + attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask( + regional_text_conditioning, img_seq_len ) + # attn_mask_with_unrestricted_img_self_attn = cls._prepare_unrestricted_attn_mask( + # regional_text_conditioning, img_seq_len + # ) return cls( regional_text_conditioning=regional_text_conditioning, - attn_mask_with_restricted_img_self_attn=attn_mask_with_restricted_img_self_attn, - attn_mask_with_unrestricted_img_self_attn=attn_mask_with_unrestricted_img_self_attn, + restricted_attn_mask=attn_mask_with_restricted_img_self_attn, + # unrestricted_attn_mask=attn_mask_with_unrestricted_img_self_attn, ) @classmethod - def _prepare_attn_mask( + def _prepare_unrestricted_attn_mask( cls, regional_text_conditioning: FluxRegionalTextConditioning, img_seq_len: int, - restrict_img_self_attn: bool, ) -> torch.Tensor: + """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that: + - img self-attention is not masked. + - img regions attend to both txt within their own region and to global prompts. + """ device = TorchDevice.choose_torch_device() # Infer txt_seq_len from the t5_embeddings tensor. txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1] - # Decide whether to compute the img self-attention region mask. - # When compute_img_self_attn_region_mask is True, img self attention is only allowed within regions. - # When compute_img_self_attn_region_mask is False, img self attention is not constrained. - has_region_masks = any(mask is not None for mask in regional_text_conditioning.image_masks) - compute_img_self_attn_region_mask = restrict_img_self_attn and has_region_masks - - # In the double stream attention blocks, the txt seq and img seq are concatenated and then attention is applied. + # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied. # Concatenation happens in the following order: [txt_seq, img_seq]. # There are 4 portions of the attention mask to consider as we prepare it: # 1. txt attends to itself @@ -101,14 +100,87 @@ class RegionalPromptingExtension: fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0 regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value - # 4. regional img attends to itself - if compute_img_self_attn_region_mask and image_mask is not None: - image_mask = image_mask.view(img_seq_len, 1) - regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + # 4. regional img attends to itself + # Allow unrestricted img self attention. + regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0 - if not compute_img_self_attn_region_mask: - # Allow unrestricted img self attention. + # Convert attention mask to boolean. + regional_attention_mask = regional_attention_mask > 0.5 + + return regional_attention_mask + + @classmethod + def _prepare_restricted_attn_mask( + cls, + regional_text_conditioning: FluxRegionalTextConditioning, + img_seq_len: int, + ) -> torch.Tensor: + """Prepare a 'restricted' attention mask. In this context, 'restricted' means that: + - img self-attention is only allowed within regions. + - img regions only attend to txt within their own region, not to global prompts. + + """ + device = TorchDevice.choose_torch_device() + + # Infer txt_seq_len from the t5_embeddings tensor. + txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1] + + # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied. + # Concatenation happens in the following order: [txt_seq, img_seq]. + # There are 4 portions of the attention mask to consider as we prepare it: + # 1. txt attends to itself + # 2. txt attends to corresponding regional img + # 3. regional img attends to corresponding txt + # 4. regional img attends to itself + + # Initialize empty attention mask. + regional_attention_mask = torch.zeros( + (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16 + ) + + # Identify background region. I.e. the region that is not covered by any region masks. + background_region_mask: None | torch.Tensor = None + for image_mask in regional_text_conditioning.image_masks: + if image_mask is not None: + if background_region_mask is None: + background_region_mask = torch.ones_like(image_mask) + background_region_mask *= 1 - image_mask + + for image_mask, t5_embedding_range in zip( + regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True + ): + # 1. txt attends to itself + regional_attention_mask[ + t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end + ] = 1.0 + + if image_mask is None: + continue + + # 2. txt attends to corresponding regional img + # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view( + 1, img_seq_len + ) + + # 3. regional img attends to corresponding txt + # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view( + img_seq_len, 1 + ) + + # 4. regional img attends to itself + image_mask = image_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + + # Handle image background regions. + if background_region_mask is None: + # There are no region masks, so allow unrestricted img self attention. regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0 + else: + # Allow background regions to attend to themselves and to the rest of the image. + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) # Convert attention mask to boolean. regional_attention_mask = regional_attention_mask > 0.5 From b54463d29427ae1b5969ded96a3016e02d6b5753 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 26 Nov 2024 17:57:31 +0000 Subject: [PATCH 11/28] Allow regional prompting background regions to attend to themselves and to the entire txt embedding. --- .../flux/extensions/regional_prompting_extension.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 7d51e12508..c3eb8e542f 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -175,12 +175,13 @@ class RegionalPromptingExtension: # Handle image background regions. if background_region_mask is None: - # There are no region masks, so allow unrestricted img self attention. - regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0 + # There are no region masks, so allow unrestricted img-img attention, and unrestricted img-txt attention. + regional_attention_mask[txt_seq_len:, :] = 1.0 + regional_attention_mask[:, txt_seq_len:] = 1.0 else: - # Allow background regions to attend to themselves and to the rest of the image. - regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) - regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) + # Allow background regions to attend to themselves and to the entire txt embedding. + regional_attention_mask[txt_seq_len:, :] += background_region_mask.view(img_seq_len, 1) + regional_attention_mask[:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) # Convert attention mask to boolean. regional_attention_mask = regional_attention_mask > 0.5 From 9a7b000995e513da0231bc1808fa460dfd3796d2 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 27 Nov 2024 17:04:35 +0000 Subject: [PATCH 12/28] Update frontend to support regional prompts with FLUX in the canvas. --- .../components/CanvasAddEntityButtons.tsx | 2 +- .../EntityListGlobalActionBarAddLayerMenu.tsx | 2 +- .../nodes/util/graph/generation/addRegions.ts | 166 ++++++++++++------ .../util/graph/generation/buildFLUXGraph.ts | 30 +++- .../frontend/web/src/services/api/schema.ts | 16 +- 5 files changed, 149 insertions(+), 67 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx index ed2bb86a88..c4462c0f4d 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasAddEntityButtons.tsx @@ -63,7 +63,7 @@ export const CanvasAddEntityButtons = memo(() => { justifyContent="flex-start" leftIcon={} onClick={addRegionalGuidance} - isDisabled={isFLUX || isSD3} + isDisabled={isSD3} > {t('controlLayers.regionalGuidance')} diff --git a/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx b/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx index 70623f54b7..40c750bc52 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/CanvasEntityList/EntityListGlobalActionBarAddLayerMenu.tsx @@ -49,7 +49,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => { } onClick={addInpaintMask}> {t('controlLayers.inpaintMask')} - } onClick={addRegionalGuidance} isDisabled={isFLUX || isSD3}> + } onClick={addRegionalGuidance} isDisabled={isSD3}> {t('controlLayers.regionalGuidance')} } onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}> diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index dcce2046da..2e0495e47c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -50,14 +50,15 @@ export const addRegions = async ( g: Graph, bbox: Rect, base: BaseModelType, - denoise: Invocation<'denoise_latents'>, - posCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, - negCond: Invocation<'compel'> | Invocation<'sdxl_compel_prompt'>, + denoise: Invocation<'denoise_latents' | 'flux_denoise'>, + posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>, + negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null, posCondCollect: Invocation<'collect'>, - negCondCollect: Invocation<'collect'>, + negCondCollect: Invocation<'collect'> | null, ipAdapterCollect: Invocation<'collect'> ): Promise => { const isSDXL = base === 'sdxl'; + const isFLUX = base === 'flux'; const validRegions = regions.filter((rg) => isValidRegion(rg, base)); const results: AddedRegionResult[] = []; @@ -94,20 +95,27 @@ export const addRegions = async ( if (region.positivePrompt) { // The main positive conditioning node result.addedPositivePrompt = true; - const regionalPosCond = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_positive_cond'), - prompt: region.positivePrompt, - style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_positive_cond'), - prompt: region.positivePrompt, - } - ); + let regionalPosCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalPosCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + style: region.positivePrompt, // TODO: Should we put the positive prompt in both fields? + }); + } else if (isFLUX) { + regionalPosCond = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + }); + } else { + regionalPosCond = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_positive_cond'), + prompt: region.positivePrompt, + }); + } // Connect the mask to the conditioning g.addEdge(maskToTensor, 'mask', regionalPosCond, 'mask'); // Connect the conditioning to the collector @@ -115,38 +123,55 @@ export const addRegions = async ( // Copy the connections to the "global" positive conditioning node to the regional cond if (posCond.type === 'compel') { for (const edge of g.getEdgesTo(posCond, ['clip', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } + } else if (posCond.type === 'sdxl_compel_prompt') { + for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCond.id; + g.addEdgeFromObj(clone); + } + } else if (posCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalPosCond.id; g.addEdgeFromObj(clone); } } else { - for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { - // Clone the edge, but change the destination node to the regional conditioning node - const clone = deepClone(edge); - clone.destination.node_id = regionalPosCond.id; - g.addEdgeFromObj(clone); - } + assert(false, 'Unsupported positive conditioning node type.'); } } if (region.negativePrompt) { - result.addedNegativePrompt = true; + assert(negCond, 'Negative conditioning node is required if there is a negative prompt'); + assert(negCondCollect, 'Negative conditioning collector is required if there is a negative prompt'); + // The main negative conditioning node - const regionalNegCond = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_negative_cond'), - prompt: region.negativePrompt, - style: region.negativePrompt, - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_negative_cond'), - prompt: region.negativePrompt, - } - ); + result.addedNegativePrompt = true; + let regionalNegCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalNegCond = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + style: region.negativePrompt, + }); + } else if (isFLUX) { + regionalNegCond = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + }); + } else { + regionalNegCond = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_negative_cond'), + prompt: region.negativePrompt, + }); + } + // Connect the mask to the conditioning g.addEdge(maskToTensor, 'mask', regionalNegCond, 'mask'); // Connect the conditioning to the collector @@ -158,17 +183,27 @@ export const addRegions = async ( clone.destination.node_id = regionalNegCond.id; g.addEdgeFromObj(clone); } - } else { + } else if (negCond.type === 'sdxl_compel_prompt') { for (const edge of g.getEdgesTo(negCond, ['clip', 'clip2', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalNegCond.id; g.addEdgeFromObj(clone); } + } else if (negCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(negCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalNegCond.id; + g.addEdgeFromObj(clone); + } + } else { + assert(false, 'Unsupported negative conditioning node type.'); } } // If we are using the "invert" auto-negative setting, we need to add an additional negative conditioning node if (region.autoNegative && region.positivePrompt) { + assert(negCondCollect, 'Negative conditioning collector is required if there is an auto-negative setting'); + result.addedAutoNegativePositivePrompt = true; // We re-use the mask image, but invert it when converting to tensor const invertTensorMask = g.addNode({ @@ -178,20 +213,27 @@ export const addRegions = async ( // Connect the OG mask image to the inverted mask-to-tensor node g.addEdge(maskToTensor, 'mask', invertTensorMask, 'mask'); // Create the conditioning node. It's going to be connected to the negative cond collector, but it uses the positive prompt - const regionalPosCondInverted = g.addNode( - isSDXL - ? { - type: 'sdxl_compel_prompt', - id: getPrefixedId('prompt_region_positive_cond_inverted'), - prompt: region.positivePrompt, - style: region.positivePrompt, - } - : { - type: 'compel', - id: getPrefixedId('prompt_region_positive_cond_inverted'), - prompt: region.positivePrompt, - } - ); + let regionalPosCondInverted: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + if (isSDXL) { + regionalPosCondInverted = g.addNode({ + type: 'sdxl_compel_prompt', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + style: region.positivePrompt, + }); + } else if (isFLUX) { + regionalPosCondInverted = g.addNode({ + type: 'flux_text_encoder', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + }); + } else { + regionalPosCondInverted = g.addNode({ + type: 'compel', + id: getPrefixedId('prompt_region_positive_cond_inverted'), + prompt: region.positivePrompt, + }); + } // Connect the inverted mask to the conditioning g.addEdge(invertTensorMask, 'mask', regionalPosCondInverted, 'mask'); // Connect the conditioning to the negative collector @@ -203,12 +245,20 @@ export const addRegions = async ( clone.destination.node_id = regionalPosCondInverted.id; g.addEdgeFromObj(clone); } - } else { + } else if (posCond.type === 'sdxl_compel_prompt') { for (const edge of g.getEdgesTo(posCond, ['clip', 'clip2', 'mask'])) { const clone = deepClone(edge); clone.destination.node_id = regionalPosCondInverted.id; g.addEdgeFromObj(clone); } + } else if (posCond.type === 'flux_text_encoder') { + for (const edge of g.getEdgesTo(posCond, ['clip', 't5_encoder', 't5_max_seq_len', 'mask'])) { + const clone = deepClone(edge); + clone.destination.node_id = regionalPosCondInverted.id; + g.addEdgeFromObj(clone); + } + } else { + assert(false, 'Unsupported positive conditioning node type.'); } } @@ -217,6 +267,8 @@ export const addRegions = async ( ); for (const { id, ipAdapter } of validRGIPAdapters) { + assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.'); + result.addedIPAdapters++; const { weight, model, clipVisionModel, method, beginEndStepPct, image } = ipAdapter; assert(model, 'IP Adapter model is required'); @@ -250,7 +302,7 @@ export const addRegions = async ( }; const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => { - // Must be have a model that matches the current base and must have a control image + // Must be a model that matches the current base and must have a control image const hasModel = Boolean(ipAdapter.model); const modelMatchesBase = ipAdapter.model?.base === base; const hasImage = Boolean(ipAdapter.image); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index d893760f3c..22be9f58e7 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -11,6 +11,7 @@ import { addImageToImage } from 'features/nodes/util/graph/generation/addImageTo import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; import { addOutpaint } from 'features/nodes/util/graph/generation/addOutpaint'; +import { addRegions } from 'features/nodes/util/graph/generation/addRegions'; import { addTextToImage } from 'features/nodes/util/graph/generation/addTextToImage'; import { addWatermarker } from 'features/nodes/util/graph/generation/addWatermarker'; import { Graph } from 'features/nodes/util/graph/generation/Graph'; @@ -79,7 +80,10 @@ export const buildFLUXGraph = async ( id: getPrefixedId('flux_text_encoder'), prompt: positivePrompt, }); - + const posCondCollect = g.addNode({ + type: 'collect', + id: getPrefixedId('pos_cond_collect'), + }); const denoise = g.addNode({ type: 'flux_denoise', id: getPrefixedId('flux_denoise'), @@ -104,13 +108,12 @@ export const buildFLUXGraph = async ( g.addEdge(modelLoader, 'clip', posCond, 'clip'); g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder'); g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len'); + g.addEdge(posCond, 'conditioning', posCondCollect, 'item'); + g.addEdge(posCondCollect, 'collection', denoise, 'positive_text_conditioning'); + g.addEdge(denoise, 'latents', l2i, 'latents'); addFLUXLoRAs(state, g, denoise, modelLoader, posCond); - g.addEdge(posCond, 'conditioning', denoise, 'positive_text_conditioning'); - - g.addEdge(denoise, 'latents', l2i, 'latents'); - const modelConfig = await fetchModelConfigWithTypeGuard(model.key, isNonRefinerMainModelConfig); assert(modelConfig.base === 'flux'); @@ -216,7 +219,22 @@ export const buildFLUXGraph = async ( }); const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); - const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters; + const regionsResult = await addRegions( + manager, + canvas.regionalGuidance.entities, + g, + canvas.bbox.rect, + modelConfig.base, + denoise, + posCond, + null, + posCondCollect, + null, + ipAdapterCollector + ); + + const totalIPAdaptersAdded = + ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); } else { diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index 90adf7517e..8878a4bcfa 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -6564,6 +6564,11 @@ export type components = { * @description The name of conditioning tensor */ conditioning_name: string; + /** + * @description The mask associated with this conditioning tensor. Excluded regions should be set to False, included regions should be set to True. + * @default null + */ + mask?: components["schemas"]["TensorField"] | null; }; /** * FluxConditioningOutput @@ -6771,15 +6776,17 @@ export type components = { */ transformer?: components["schemas"]["TransformerField"]; /** + * Positive Text Conditioning * @description Positive conditioning tensor * @default null */ - positive_text_conditioning?: components["schemas"]["FluxConditioningField"]; + positive_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][]; /** + * Negative Text Conditioning * @description Negative conditioning tensor. Can be None if cfg_scale is 1.0. * @default null */ - negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | null; + negative_text_conditioning?: components["schemas"]["FluxConditioningField"] | components["schemas"]["FluxConditioningField"][] | null; /** * CFG Scale * @description Classifier-Free Guidance scale @@ -7133,6 +7140,11 @@ export type components = { * @default null */ prompt?: string; + /** + * @description A mask defining the region that this conditioning prompt applies to. + * @default null + */ + mask?: components["schemas"]["TensorField"] | null; /** * type * @default flux_text_encoder From fa5653cdf7adc7b902d7a08c6ac573c0b8631516 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 27 Nov 2024 17:08:42 +0000 Subject: [PATCH 13/28] Remove unused 'denoise' param to addRegions(). --- .../web/src/features/nodes/util/graph/generation/addRegions.ts | 2 -- .../src/features/nodes/util/graph/generation/buildFLUXGraph.ts | 1 - .../src/features/nodes/util/graph/generation/buildSD1Graph.ts | 1 - .../src/features/nodes/util/graph/generation/buildSDXLGraph.ts | 1 - 4 files changed, 5 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index 2e0495e47c..cdeb30a6f6 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -35,7 +35,6 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => * @param regions Array of regions to add * @param g The graph to add the layers to * @param base The base model type - * @param denoise The main denoise node * @param posCond The positive conditioning node * @param negCond The negative conditioning node * @param posCondCollect The positive conditioning collector @@ -50,7 +49,6 @@ export const addRegions = async ( g: Graph, bbox: Rect, base: BaseModelType, - denoise: Invocation<'denoise_latents' | 'flux_denoise'>, posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>, negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null, posCondCollect: Invocation<'collect'>, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 22be9f58e7..b1e5c60c2a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -225,7 +225,6 @@ export const buildFLUXGraph = async ( g, canvas.bbox.rect, modelConfig.base, - denoise, posCond, null, posCondCollect, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 4008fd05d6..69e145cbe2 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -271,7 +271,6 @@ export const buildSD1Graph = async ( g, canvas.bbox.rect, modelConfig.base, - denoise, posCond, negCond, posCondCollect, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 37ab697522..8d7fb67c10 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -276,7 +276,6 @@ export const buildSDXLGraph = async ( g, canvas.bbox.rect, modelConfig.base, - denoise, posCond, negCond, posCondCollect, From e9701851612d852b3e64c1ff1e923723f2896180 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 27 Nov 2024 22:13:07 +0000 Subject: [PATCH 14/28] Tweak flux regional prompting attention scheme based on latest experimentation. --- .../regional_prompting_extension.py | 55 +++++++++++++------ 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index c3eb8e542f..f5f203af69 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -154,24 +154,43 @@ class RegionalPromptingExtension: t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end ] = 1.0 - if image_mask is None: - continue + if image_mask is not None: + # 2. txt attends to corresponding regional img + # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = ( + image_mask.view(1, img_seq_len) + ) - # 2. txt attends to corresponding regional img - # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. - regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = image_mask.view( - 1, img_seq_len - ) + # 3. regional img attends to corresponding txt + # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = ( + image_mask.view(img_seq_len, 1) + ) - # 3. regional img attends to corresponding txt - # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. - regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = image_mask.view( - img_seq_len, 1 - ) + # 4. regional img attends to itself + image_mask = image_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + else: + if background_region_mask is None: + # There are no region masks, so we don't need to do anything here - this case is handled below. + continue - # 4. regional img attends to itself - image_mask = image_mask.view(img_seq_len, 1) - regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T + # We don't allow attention between non-background image regions and global prompts. This helps to ensure + # that regions focus on their local prompts. We do, however, allow attention between background regions + # and global prompts. If we didn't do this, then the background regions would not attend to any txt + # embeddings, which we found experimentally to cause artifacts. + + # 2. global txt attends to background region + # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = ( + background_region_mask.view(1, img_seq_len) + ) + + # 3. background region attends to global txt + # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = ( + background_region_mask.view(img_seq_len, 1) + ) # Handle image background regions. if background_region_mask is None: @@ -179,9 +198,9 @@ class RegionalPromptingExtension: regional_attention_mask[txt_seq_len:, :] = 1.0 regional_attention_mask[:, txt_seq_len:] = 1.0 else: - # Allow background regions to attend to themselves and to the entire txt embedding. - regional_attention_mask[txt_seq_len:, :] += background_region_mask.view(img_seq_len, 1) - regional_attention_mask[:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) + # Allow background regions to attend to themselves. + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) # Convert attention mask to boolean. regional_attention_mask = regional_attention_mask > 0.5 From 3ebd8d6c07833ad2ee273f28eb1d9bc5cd6186b1 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 27 Nov 2024 22:13:25 +0000 Subject: [PATCH 15/28] Delete outdated TODO comment. --- invokeai/backend/flux/denoise.py | 1 - 1 file changed, 1 deletion(-) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index ae541ba7d8..9e1305e467 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -50,7 +50,6 @@ def denoise( # Run ControlNet models. controlnet_residuals: list[ControlNetFluxOutput] = [] for controlnet_extension in controlnet_extensions: - # TODO(ryand): Think about how to handle regional prompting with ControlNet. controlnet_residuals.append( controlnet_extension.run_controlnet( timestep_index=step_index, From 6565cea039be1b808c7055622584a2361eff6861 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 27 Nov 2024 22:16:44 +0000 Subject: [PATCH 16/28] Comment unused _prepare_unrestricted_attn_mask(...) for future reference. --- .../regional_prompting_extension.py | 100 +++++++++--------- 1 file changed, 48 insertions(+), 52 deletions(-) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index f5f203af69..834463137a 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -12,18 +12,16 @@ from invokeai.backend.util.mask import to_standard_float_mask class RegionalPromptingExtension: """A class for managing regional prompting with FLUX. - Implementation inspired by: https://arxiv.org/pdf/2411.02395 + This implementation is inspired by https://arxiv.org/pdf/2411.02395 (though there are significant differences). """ def __init__( self, regional_text_conditioning: FluxRegionalTextConditioning, restricted_attn_mask: torch.Tensor | None = None, - # unrestricted_attn_mask: torch.Tensor | None = None, ): self.regional_text_conditioning = regional_text_conditioning self.restricted_attn_mask = restricted_attn_mask - # self.unrestricted_attn_mask = unrestricted_attn_mask def get_double_stream_attn_mask(self, block_index: int) -> torch.Tensor | None: order = [self.restricted_attn_mask, None] @@ -45,69 +43,67 @@ class RegionalPromptingExtension: attn_mask_with_restricted_img_self_attn = cls._prepare_restricted_attn_mask( regional_text_conditioning, img_seq_len ) - # attn_mask_with_unrestricted_img_self_attn = cls._prepare_unrestricted_attn_mask( - # regional_text_conditioning, img_seq_len - # ) return cls( regional_text_conditioning=regional_text_conditioning, restricted_attn_mask=attn_mask_with_restricted_img_self_attn, - # unrestricted_attn_mask=attn_mask_with_unrestricted_img_self_attn, ) - @classmethod - def _prepare_unrestricted_attn_mask( - cls, - regional_text_conditioning: FluxRegionalTextConditioning, - img_seq_len: int, - ) -> torch.Tensor: - """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that: - - img self-attention is not masked. - - img regions attend to both txt within their own region and to global prompts. - """ - device = TorchDevice.choose_torch_device() + # Keeping _prepare_unrestricted_attn_mask for reference as an alternative masking strategy: + # + # @classmethod + # def _prepare_unrestricted_attn_mask( + # cls, + # regional_text_conditioning: FluxRegionalTextConditioning, + # img_seq_len: int, + # ) -> torch.Tensor: + # """Prepare an 'unrestricted' attention mask. In this context, 'unrestricted' means that: + # - img self-attention is not masked. + # - img regions attend to both txt within their own region and to global prompts. + # """ + # device = TorchDevice.choose_torch_device() - # Infer txt_seq_len from the t5_embeddings tensor. - txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1] + # # Infer txt_seq_len from the t5_embeddings tensor. + # txt_seq_len = regional_text_conditioning.t5_embeddings.shape[1] - # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied. - # Concatenation happens in the following order: [txt_seq, img_seq]. - # There are 4 portions of the attention mask to consider as we prepare it: - # 1. txt attends to itself - # 2. txt attends to corresponding regional img - # 3. regional img attends to corresponding txt - # 4. regional img attends to itself + # # In the attention blocks, the txt seq and img seq are concatenated and then attention is applied. + # # Concatenation happens in the following order: [txt_seq, img_seq]. + # # There are 4 portions of the attention mask to consider as we prepare it: + # # 1. txt attends to itself + # # 2. txt attends to corresponding regional img + # # 3. regional img attends to corresponding txt + # # 4. regional img attends to itself - # Initialize empty attention mask. - regional_attention_mask = torch.zeros( - (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16 - ) + # # Initialize empty attention mask. + # regional_attention_mask = torch.zeros( + # (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16 + # ) - for image_mask, t5_embedding_range in zip( - regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True - ): - # 1. txt attends to itself - regional_attention_mask[ - t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end - ] = 1.0 + # for image_mask, t5_embedding_range in zip( + # regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True + # ): + # # 1. txt attends to itself + # regional_attention_mask[ + # t5_embedding_range.start : t5_embedding_range.end, t5_embedding_range.start : t5_embedding_range.end + # ] = 1.0 - # 2. txt attends to corresponding regional img - # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. - fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0 - regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value + # # 2. txt attends to corresponding regional img + # # Note that we reshape to (1, img_seq_len) to ensure broadcasting works as desired. + # fill_value = image_mask.view(1, img_seq_len) if image_mask is not None else 1.0 + # regional_attention_mask[t5_embedding_range.start : t5_embedding_range.end, txt_seq_len:] = fill_value - # 3. regional img attends to corresponding txt - # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. - fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0 - regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value + # # 3. regional img attends to corresponding txt + # # Note that we reshape to (img_seq_len, 1) to ensure broadcasting works as desired. + # fill_value = image_mask.view(img_seq_len, 1) if image_mask is not None else 1.0 + # regional_attention_mask[txt_seq_len:, t5_embedding_range.start : t5_embedding_range.end] = fill_value - # 4. regional img attends to itself - # Allow unrestricted img self attention. - regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0 + # # 4. regional img attends to itself + # # Allow unrestricted img self attention. + # regional_attention_mask[txt_seq_len:, txt_seq_len:] = 1.0 - # Convert attention mask to boolean. - regional_attention_mask = regional_attention_mask > 0.5 + # # Convert attention mask to boolean. + # regional_attention_mask = regional_attention_mask > 0.5 - return regional_attention_mask + # return regional_attention_mask @classmethod def _prepare_restricted_attn_mask( From 64364e791184447fa695831e49392cffd0b5ab35 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 27 Nov 2024 22:40:10 +0000 Subject: [PATCH 17/28] Short-circuit if there are no region masks in FLUX and don't apply attention masking. --- .../regional_prompting_extension.py | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py index 834463137a..f1086c3286 100644 --- a/invokeai/backend/flux/extensions/regional_prompting_extension.py +++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py @@ -110,12 +110,25 @@ class RegionalPromptingExtension: cls, regional_text_conditioning: FluxRegionalTextConditioning, img_seq_len: int, - ) -> torch.Tensor: + ) -> torch.Tensor | None: """Prepare a 'restricted' attention mask. In this context, 'restricted' means that: - img self-attention is only allowed within regions. - img regions only attend to txt within their own region, not to global prompts. - """ + # Identify background region. I.e. the region that is not covered by any region masks. + background_region_mask: None | torch.Tensor = None + for image_mask in regional_text_conditioning.image_masks: + if image_mask is not None: + if background_region_mask is None: + background_region_mask = torch.ones_like(image_mask) + background_region_mask *= 1 - image_mask + + if background_region_mask is None: + # There are no region masks, short-circuit and return None. + # TODO(ryand): We could restrict txt-txt attention across multiple global prompts, but this would + # is a rare use case and would make the logic here significantly more complicated. + return None + device = TorchDevice.choose_torch_device() # Infer txt_seq_len from the t5_embeddings tensor. @@ -134,14 +147,6 @@ class RegionalPromptingExtension: (txt_seq_len + img_seq_len, txt_seq_len + img_seq_len), device=device, dtype=torch.float16 ) - # Identify background region. I.e. the region that is not covered by any region masks. - background_region_mask: None | torch.Tensor = None - for image_mask in regional_text_conditioning.image_masks: - if image_mask is not None: - if background_region_mask is None: - background_region_mask = torch.ones_like(image_mask) - background_region_mask *= 1 - image_mask - for image_mask, t5_embedding_range in zip( regional_text_conditioning.image_masks, regional_text_conditioning.t5_embedding_ranges, strict=True ): @@ -167,10 +172,6 @@ class RegionalPromptingExtension: image_mask = image_mask.view(img_seq_len, 1) regional_attention_mask[txt_seq_len:, txt_seq_len:] += image_mask @ image_mask.T else: - if background_region_mask is None: - # There are no region masks, so we don't need to do anything here - this case is handled below. - continue - # We don't allow attention between non-background image regions and global prompts. This helps to ensure # that regions focus on their local prompts. We do, however, allow attention between background regions # and global prompts. If we didn't do this, then the background regions would not attend to any txt @@ -188,15 +189,9 @@ class RegionalPromptingExtension: background_region_mask.view(img_seq_len, 1) ) - # Handle image background regions. - if background_region_mask is None: - # There are no region masks, so allow unrestricted img-img attention, and unrestricted img-txt attention. - regional_attention_mask[txt_seq_len:, :] = 1.0 - regional_attention_mask[:, txt_seq_len:] = 1.0 - else: - # Allow background regions to attend to themselves. - regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) - regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) + # Allow background regions to attend to themselves. + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(img_seq_len, 1) + regional_attention_mask[txt_seq_len:, txt_seq_len:] += background_region_mask.view(1, img_seq_len) # Convert attention mask to boolean. regional_attention_mask = regional_attention_mask > 0.5 From 5d8dd6e26ec5111ba1da8fd157501ef9236ca20c Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 28 Nov 2024 18:49:29 +0000 Subject: [PATCH 18/28] Fix FLUX regional negative prompts. --- invokeai/backend/flux/denoise.py | 1 + 1 file changed, 1 insertion(+) diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py index 9e1305e467..66e3984edb 100644 --- a/invokeai/backend/flux/denoise.py +++ b/invokeai/backend/flux/denoise.py @@ -109,6 +109,7 @@ def denoise( controlnet_double_block_residuals=None, controlnet_single_block_residuals=None, ip_adapter_extensions=neg_ip_adapter_extensions, + regional_prompting_extension=neg_regional_prompting_extension, ) pred = neg_pred + step_cfg_scale * (pred - neg_pred) From c276b60af965a0d4e53b34f1ab5fb5cd3bcd3468 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Thu, 28 Nov 2024 14:26:24 +1000 Subject: [PATCH 19/28] tidy(ui): use object for addRegions graph builder util arg --- .../nodes/util/graph/generation/addRegions.ts | 39 +++++++++++++------ .../util/graph/generation/buildFLUXGraph.ts | 24 ++++++------ .../util/graph/generation/buildSD1Graph.ts | 20 +++++----- .../util/graph/generation/buildSDXLGraph.ts | 20 +++++----- 4 files changed, 59 insertions(+), 44 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index cdeb30a6f6..1c058c5f4c 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -30,10 +30,25 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => return isEnabled && (hasTextPrompt || hasIPAdapter); }; +type AddRegionsArg = { + manager: CanvasManager; + regions: CanvasRegionalGuidanceState[]; + g: Graph; + bbox: Rect; + base: BaseModelType; + posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; + negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null; + posCondCollect: Invocation<'collect'>; + negCondCollect: Invocation<'collect'> | null; + ipAdapterCollect: Invocation<'collect'>; +}; + /** * Adds regional guidance to the graph + * @param manager The canvas manager * @param regions Array of regions to add * @param g The graph to add the layers to + * @param bbox The bounding box * @param base The base model type * @param posCond The positive conditioning node * @param negCond The negative conditioning node @@ -43,18 +58,18 @@ const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => * @returns A promise that resolves to the regions that were successfully added to the graph */ -export const addRegions = async ( - manager: CanvasManager, - regions: CanvasRegionalGuidanceState[], - g: Graph, - bbox: Rect, - base: BaseModelType, - posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>, - negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null, - posCondCollect: Invocation<'collect'>, - negCondCollect: Invocation<'collect'> | null, - ipAdapterCollect: Invocation<'collect'> -): Promise => { +export const addRegions = async ({ + manager, + regions, + g, + bbox, + base, + posCond, + negCond, + posCondCollect, + negCondCollect, + ipAdapterCollect, +}: AddRegionsArg): Promise => { const isSDXL = base === 'sdxl'; const isFLUX = base === 'flux'; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index b1e5c60c2a..4b3cad0774 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -213,31 +213,31 @@ export const buildFLUXGraph = async ( g.deleteNode(controlNetCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); + const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); - const regionsResult = await addRegions( + const regionsResult = await addRegions({ manager, - canvas.regionalGuidance.entities, + regions: canvas.regionalGuidance.entities, g, - canvas.bbox.rect, - modelConfig.base, + bbox: canvas.bbox.rect, + base: modelConfig.base, posCond, - null, + negCond: null, posCondCollect, - null, - ipAdapterCollector - ); + negCondCollect: null, + ipAdapterCollect, + }); const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 69e145cbe2..ab38035b4a 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -259,31 +259,31 @@ export const buildSD1Graph = async ( g.deleteNode(t2iAdapterCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); + const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); - const regionsResult = await addRegions( + const regionsResult = await addRegions({ manager, - canvas.regionalGuidance.entities, + regions: canvas.regionalGuidance.entities, g, - canvas.bbox.rect, - modelConfig.base, + bbox: canvas.bbox.rect, + base: modelConfig.base, posCond, negCond, posCondCollect, negCondCollect, - ipAdapterCollector - ); + ipAdapterCollect, + }); const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) { diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 8d7fb67c10..4d84c025ec 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -264,31 +264,31 @@ export const buildSDXLGraph = async ( g.deleteNode(t2iAdapterCollector.id); } - const ipAdapterCollector = g.addNode({ + const ipAdapterCollect = g.addNode({ type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollector, modelConfig.base); + const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); - const regionsResult = await addRegions( + const regionsResult = await addRegions({ manager, - canvas.regionalGuidance.entities, + regions: canvas.regionalGuidance.entities, g, - canvas.bbox.rect, - modelConfig.base, + bbox: canvas.bbox.rect, + base: modelConfig.base, posCond, negCond, posCondCollect, negCondCollect, - ipAdapterCollector - ); + ipAdapterCollect, + }); const totalIPAdaptersAdded = ipAdapterResult.addedIPAdapters + regionsResult.reduce((acc, r) => acc + r.addedIPAdapters, 0); if (totalIPAdaptersAdded > 0) { - g.addEdge(ipAdapterCollector, 'collection', denoise, 'ip_adapter'); + g.addEdge(ipAdapterCollect, 'collection', denoise, 'ip_adapter'); } else { - g.deleteNode(ipAdapterCollector.id); + g.deleteNode(ipAdapterCollect.id); } if (state.system.shouldUseNSFWChecker) { From 484aaf1595df465bcebc9524284d22c1f78dd74b Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:12:32 +1000 Subject: [PATCH 20/28] feat(ui): add canvas layer validation utils These helpers consolidate layer validation checks. For example, checking that the layer has content drawn, is compatible with the selected main model, has valid reference images, etc. --- invokeai/frontend/web/public/locales/en.json | 2 +- .../controlLayers/store/validators.ts | 148 ++++++++++++++++++ 2 files changed, 149 insertions(+), 1 deletion(-) create mode 100644 invokeai/frontend/web/src/features/controlLayers/store/validators.ts diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index ef736b6dac..dfb75eb2a8 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1050,7 +1050,7 @@ "ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model", "ipAdapterNoImageSelected": "no IP Adapter image selected", "rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters", - "rgNoRegion": "no region selected" + "emptyLayer": "empty layer" } }, "maskBlur": "Mask Blur", diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts new file mode 100644 index 0000000000..604ab3338b --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -0,0 +1,148 @@ +import type { + CanvasControlLayerState, + CanvasInpaintMaskState, + CanvasRasterLayerState, + CanvasReferenceImageState, + CanvasRegionalGuidanceState, +} from 'features/controlLayers/store/types'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; +import type { TFunction } from 'i18next'; + +export const getRegionalGuidanceWarnings = ( + entity: CanvasRegionalGuidanceState, + model: ParameterModel | null, + t: TFunction +): string[] => { + const warnings: string[] = []; + + if (entity.objects.length === 0) { + // Layer is in empty state - skip other checks + warnings.push(t('parameters.invoke.layer.emptyLayer')); + } else { + if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) { + // Must have at least 1 prompt or IP Adapter + warnings.push(t('parameters.invoke.layer.rgNoPromptsOrIPAdapters')); + } + + if (model) { + if (model.base === 'sd-3' || model.base === 'sd-2') { + // Unsupported model architecture + warnings.push(t('controlLayers.invalidBaseModelType')); + } else if (model.base === 'flux') { + // Some features are not supported for flux models + if (entity.negativePrompt !== null) { + warnings.push(t('parameters.invoke.layer.rgNegativePromptNotSupported')); + } + if (entity.referenceImages.length > 0) { + warnings.push(t('parameters.invoke.layer.rgReferenceImagesNotSupported')); + } + if (entity.autoNegative) { + warnings.push(t('parameters.invoke.layer.rgAutoNegativeNotSupported')); + } + } else { + entity.referenceImages.forEach(({ ipAdapter }) => { + if (!ipAdapter.model) { + // No model selected + warnings.push(t('parameters.invoke.layer.ipAdapterNoModelSelected')); + } else if (ipAdapter.model.base !== model.base) { + // Supported model architecture but doesn't match + warnings.push(t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); + } + + if (!ipAdapter.image) { + // No image selected + warnings.push(t('parameters.invoke.layer.ipAdapterNoImageSelected')); + } + }); + } + } + } + + return warnings; +}; + +export const getGlobalReferenceImageWarnings = ( + entity: CanvasReferenceImageState, + model: ParameterModel | null, + t: TFunction +): string[] => { + const warnings: string[] = []; + + if (!entity.ipAdapter.model) { + // No model selected + warnings.push(t('parameters.invoke.layer.ipAdapterNoModelSelected')); + } else if (model) { + if (model.base === 'sd-3' || model.base === 'sd-2') { + // Unsupported model architecture + warnings.push(t('controlLayers.invalidBaseModelType')); + } else if (entity.ipAdapter.model.base !== model.base) { + // Supported model architecture but doesn't match + warnings.push(t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); + } + } + + if (!entity.ipAdapter.image) { + // No image selected + warnings.push(t('parameters.invoke.layer.ipAdapterNoImageSelected')); + } + + return warnings; +}; + +export const getControlLayerWarnings = ( + entity: CanvasControlLayerState, + model: ParameterModel | null, + t: TFunction +): string[] => { + const warnings: string[] = []; + + if (entity.objects.length === 0) { + // Layer is in empty state - skip other checks + warnings.push(t('parameters.invoke.layer.emptyLayer')); + } else { + if (!entity.controlAdapter.model) { + // No model selected + warnings.push(t('parameters.invoke.layer.controlAdapterNoModelSelected')); + } else if (model) { + if (model.base === 'sd-3' || model.base === 'sd-2') { + // Unsupported model architecture + warnings.push(t('controlLayers.invalidBaseModelType')); + } else if (entity.controlAdapter.model.base !== model.base) { + // Supported model architecture but doesn't match + warnings.push(t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel')); + } + } + } + + return warnings; +}; + +export const getRasterLayerWarnings = ( + entity: CanvasRasterLayerState, + _model: ParameterModel | null, + t: TFunction +): string[] => { + const warnings: string[] = []; + + if (entity.objects.length === 0) { + // Layer is in empty state - skip other checks + warnings.push(t('parameters.invoke.layer.emptyLayer')); + } + + return warnings; +}; + +export const getInpaintMaskWarnings = ( + entity: CanvasInpaintMaskState, + _model: ParameterModel | null, + t: TFunction +): string[] => { + const warnings: string[] = []; + + if (entity.objects.length === 0) { + // Layer is in empty state - skip other checks + warnings.push(t('parameters.invoke.layer.emptyLayer')); + } + + return warnings; +}; From 7dd33b0f39f5192de38919ad903ec07ef119c109 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:13:47 +1000 Subject: [PATCH 21/28] feat(ui): add indicator to canvas layer headers, displaying validation warnings If there are any issues with the layer, the icon is displayed. If the layer is disabled, the icon is greyed out but still visible. --- .../CanvasEntityHeaderCommonActions.tsx | 2 + .../common/CanvasEntityHeaderWarnings.tsx | 97 +++++++++++++++++++ 2 files changed, 99 insertions(+) create mode 100644 invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx diff --git a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderCommonActions.tsx b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderCommonActions.tsx index cc89cb02f6..9c22778ff6 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderCommonActions.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderCommonActions.tsx @@ -1,6 +1,7 @@ import { Flex } from '@invoke-ai/ui-library'; import { CanvasEntityDeleteButton } from 'features/controlLayers/components/common/CanvasEntityDeleteButton'; import { CanvasEntityEnabledToggle } from 'features/controlLayers/components/common/CanvasEntityEnabledToggle'; +import { CanvasEntityHeaderWarnings } from 'features/controlLayers/components/common/CanvasEntityHeaderWarnings'; import { CanvasEntityIsBookmarkedForQuickSwitchToggle } from 'features/controlLayers/components/common/CanvasEntityIsBookmarkedForQuickSwitchToggle'; import { CanvasEntityIsLockedToggle } from 'features/controlLayers/components/common/CanvasEntityIsLockedToggle'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; @@ -11,6 +12,7 @@ export const CanvasEntityHeaderCommonActions = memo(() => { return ( + {entityIdentifier.type !== 'reference_image' && } diff --git a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx new file mode 100644 index 0000000000..9221b81249 --- /dev/null +++ b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx @@ -0,0 +1,97 @@ +import { IconButton, ListItem, UnorderedList } from '@invoke-ai/ui-library'; +import { createSelector } from '@reduxjs/toolkit'; +import { EMPTY_ARRAY } from 'app/store/constants'; +import { useAppSelector } from 'app/store/storeHooks'; +import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; +import { useEntityIsEnabled } from 'features/controlLayers/hooks/useEntityIsEnabled'; +import { selectModel } from 'features/controlLayers/store/paramsSlice'; +import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors'; +import type { CanvasEntityIdentifier } from 'features/controlLayers/store/types'; +import { + getControlLayerWarnings, + getGlobalReferenceImageWarnings, + getInpaintMaskWarnings, + getRasterLayerWarnings, + getRegionalGuidanceWarnings, +} from 'features/controlLayers/store/validators'; +import type { TFunction } from 'i18next'; +import { upperFirst } from 'lodash-es'; +import { memo, useMemo } from 'react'; +import { useTranslation } from 'react-i18next'; +import { PiWarningBold } from 'react-icons/pi'; +import type { Equals } from 'tsafe'; +import { assert } from 'tsafe'; + +const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunction) => { + return createSelector(selectCanvasSlice, selectModel, (canvas, model) => { + // This component is used within a so we can safely assume that the entity exists. + // Should never throw. + const entity = selectEntityOrThrow(canvas, entityIdentifier, 'CanvasEntityHeaderWarnings'); + + let warnings: string[] = []; + + const entityType = entity.type; + + if (entityType === 'control_layer') { + warnings = getControlLayerWarnings(entity, model, t); + } else if (entityType === 'regional_guidance') { + warnings = getRegionalGuidanceWarnings(entity, model, t); + } else if (entityType === 'inpaint_mask') { + warnings = getInpaintMaskWarnings(entity, model, t); + } else if (entityType === 'raster_layer') { + warnings = getRasterLayerWarnings(entity, model, t); + } else if (entityType === 'reference_image') { + warnings = getGlobalReferenceImageWarnings(entity, model, t); + } else { + assert>(false, 'Unexpected entity type'); + } + + // Return a stable reference if there are no warnings + if (warnings.length === 0) { + return EMPTY_ARRAY; + } + + return warnings.map(upperFirst); + }); +}; + +export const CanvasEntityHeaderWarnings = memo(() => { + const entityIdentifier = useEntityIdentifierContext(); + const { t } = useTranslation(); + const isEnabled = useEntityIsEnabled(entityIdentifier); + const selectWarnings = useMemo(() => buildSelectWarnings(entityIdentifier, t), [entityIdentifier, t]); + const warnings = useAppSelector(selectWarnings); + + if (warnings.length === 0) { + return null; + } + + return ( + // Using IconButton here bc it matches the styling of the actual buttons in the header without any fanagling, but + // it's not a button + } + icon={} + colorScheme="warning" + isDisabled={!isEnabled} + /> + ); +}); + +CanvasEntityHeaderWarnings.displayName = 'CanvasEntityHeaderWarnings'; + +const TooltipContent = memo((props: { warnings: string[] }) => { + return ( + + {props.warnings.map((warning, index) => ( + {warning} + ))} + + ); +}); +TooltipContent.displayName = 'TooltipContent'; From 0be796a808523a95edb4985aabe4d10848a9a225 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:14:26 +1000 Subject: [PATCH 22/28] feat(ui): use layer validation utils in invoke readiness utils --- .../web/src/features/queue/store/readiness.ts | 74 +++++++------------ 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index 0d14bba73d..b44e4e6266 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -4,6 +4,13 @@ import type { ParamsState } from 'features/controlLayers/store/paramsSlice'; import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasSlice } from 'features/controlLayers/store/selectors'; import type { CanvasState } from 'features/controlLayers/store/types'; +import { + getControlLayerWarnings, + getGlobalReferenceImageWarnings, + getInpaintMaskWarnings, + getRasterLayerWarnings, + getRegionalGuidanceWarnings, +} from 'features/controlLayers/store/validators'; import type { DynamicPromptsState } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { selectDynamicPromptsSlice } from 'features/dynamicPrompts/store/dynamicPromptsSlice'; import { getShouldProcessPrompt } from 'features/dynamicPrompts/util/getShouldProcessPrompt'; @@ -278,15 +285,8 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; - // Must have model - if (!controlLayer.controlAdapter.model) { - problems.push(i18n.t('parameters.invoke.layer.controlAdapterNoModelSelected')); - } - // Model base must match - if (controlLayer.controlAdapter.model?.base !== model?.base) { - problems.push(i18n.t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel')); - } + const problems = getControlLayerWarnings(controlLayer, model, i18n.t); + if (problems.length) { const content = upperFirst(problems.join(', ')); reasons.push({ prefix, content }); @@ -300,20 +300,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; - - // Must have model - if (!entity.ipAdapter.model) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected')); - } - // Model base must match - if (entity.ipAdapter.model?.base !== model?.base) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); - } - // Must have an image - if (!entity.ipAdapter.image) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); - } + const problems = getGlobalReferenceImageWarnings(entity, model, i18n.t); if (problems.length) { const content = upperFirst(problems.join(', ')); @@ -328,29 +315,7 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; - // Must have a region - if (entity.objects.length === 0) { - problems.push(i18n.t('parameters.invoke.layer.rgNoRegion')); - } - // Must have at least 1 prompt or IP Adapter - if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) { - problems.push(i18n.t('parameters.invoke.layer.rgNoPromptsOrIPAdapters')); - } - entity.referenceImages.forEach(({ ipAdapter }) => { - // Must have model - if (!ipAdapter.model) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoModelSelected')); - } - // Model base must match - if (ipAdapter.model?.base !== model?.base) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); - } - // Must have an image - if (!ipAdapter.image) { - problems.push(i18n.t('parameters.invoke.layer.ipAdapterNoImageSelected')); - } - }); + const problems = getRegionalGuidanceWarnings(entity, model, i18n.t); if (problems.length) { const content = upperFirst(problems.join(', ')); @@ -365,7 +330,22 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems: string[] = []; + const problems = getRasterLayerWarnings(entity, model, i18n.t); + + if (problems.length) { + const content = upperFirst(problems.join(', ')); + reasons.push({ prefix, content }); + } + }); + + canvas.inpaintMasks.entities + .filter((entity) => entity.isEnabled) + .forEach((entity, i) => { + const layerLiteral = i18n.t('controlLayers.layer_one'); + const layerNumber = i + 1; + const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); + const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; + const problems = getInpaintMaskWarnings(entity, model, i18n.t); if (problems.length) { const content = upperFirst(problems.join(', ')); From 3905c97e32b47956293e97903e81667f53c7fd81 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:25:09 +1000 Subject: [PATCH 23/28] feat(ui): return translation keys from validation utils instead of translated strings --- .../common/CanvasEntityHeaderWarnings.tsx | 12 ++-- .../controlLayers/store/validators.ts | 63 +++++++------------ .../web/src/features/queue/store/readiness.ts | 20 +++--- 3 files changed, 40 insertions(+), 55 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx index 9221b81249..c6ef5ad7db 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx @@ -33,15 +33,15 @@ const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunct const entityType = entity.type; if (entityType === 'control_layer') { - warnings = getControlLayerWarnings(entity, model, t); + warnings = getControlLayerWarnings(entity, model); } else if (entityType === 'regional_guidance') { - warnings = getRegionalGuidanceWarnings(entity, model, t); + warnings = getRegionalGuidanceWarnings(entity, model); } else if (entityType === 'inpaint_mask') { - warnings = getInpaintMaskWarnings(entity, model, t); + warnings = getInpaintMaskWarnings(entity, model); } else if (entityType === 'raster_layer') { - warnings = getRasterLayerWarnings(entity, model, t); + warnings = getRasterLayerWarnings(entity, model); } else if (entityType === 'reference_image') { - warnings = getGlobalReferenceImageWarnings(entity, model, t); + warnings = getGlobalReferenceImageWarnings(entity, model); } else { assert>(false, 'Unexpected entity type'); } @@ -51,7 +51,7 @@ const buildSelectWarnings = (entityIdentifier: CanvasEntityIdentifier, t: TFunct return EMPTY_ARRAY; } - return warnings.map(upperFirst); + return warnings.map((w) => t(w)).map(upperFirst); }); }; diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts index 604ab3338b..5b63572140 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -6,52 +6,50 @@ import type { CanvasRegionalGuidanceState, } from 'features/controlLayers/store/types'; import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; -import type { TFunction } from 'i18next'; export const getRegionalGuidanceWarnings = ( entity: CanvasRegionalGuidanceState, - model: ParameterModel | null, - t: TFunction + model: ParameterModel | null ): string[] => { const warnings: string[] = []; if (entity.objects.length === 0) { // Layer is in empty state - skip other checks - warnings.push(t('parameters.invoke.layer.emptyLayer')); + warnings.push('parameters.invoke.layer.emptyLayer'); } else { if (entity.positivePrompt === null && entity.negativePrompt === null && entity.referenceImages.length === 0) { // Must have at least 1 prompt or IP Adapter - warnings.push(t('parameters.invoke.layer.rgNoPromptsOrIPAdapters')); + warnings.push('parameters.invoke.layer.rgNoPromptsOrIPAdapters'); } if (model) { if (model.base === 'sd-3' || model.base === 'sd-2') { // Unsupported model architecture - warnings.push(t('controlLayers.invalidBaseModelType')); + warnings.push('controlLayers.invalidBaseModelType'); } else if (model.base === 'flux') { // Some features are not supported for flux models if (entity.negativePrompt !== null) { - warnings.push(t('parameters.invoke.layer.rgNegativePromptNotSupported')); + warnings.push('parameters.invoke.layer.rgNegativePromptNotSupported'); } if (entity.referenceImages.length > 0) { - warnings.push(t('parameters.invoke.layer.rgReferenceImagesNotSupported')); + warnings.push('parameters.invoke.layer.rgReferenceImagesNotSupported'); } if (entity.autoNegative) { - warnings.push(t('parameters.invoke.layer.rgAutoNegativeNotSupported')); + warnings.push('parameters.invoke.layer.rgAutoNegativeNotSupported'); } } else { entity.referenceImages.forEach(({ ipAdapter }) => { if (!ipAdapter.model) { // No model selected - warnings.push(t('parameters.invoke.layer.ipAdapterNoModelSelected')); + warnings.push('parameters.invoke.layer.ipAdapterNoModelSelected'); } else if (ipAdapter.model.base !== model.base) { // Supported model architecture but doesn't match - warnings.push(t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); + warnings.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'); } if (!ipAdapter.image) { // No image selected - warnings.push(t('parameters.invoke.layer.ipAdapterNoImageSelected')); + warnings.push('parameters.invoke.layer.ipAdapterNoImageSelected'); } }); } @@ -63,53 +61,48 @@ export const getRegionalGuidanceWarnings = ( export const getGlobalReferenceImageWarnings = ( entity: CanvasReferenceImageState, - model: ParameterModel | null, - t: TFunction + model: ParameterModel | null ): string[] => { const warnings: string[] = []; if (!entity.ipAdapter.model) { // No model selected - warnings.push(t('parameters.invoke.layer.ipAdapterNoModelSelected')); + warnings.push('parameters.invoke.layer.ipAdapterNoModelSelected'); } else if (model) { if (model.base === 'sd-3' || model.base === 'sd-2') { // Unsupported model architecture - warnings.push(t('controlLayers.invalidBaseModelType')); + warnings.push('controlLayers.invalidBaseModelType'); } else if (entity.ipAdapter.model.base !== model.base) { // Supported model architecture but doesn't match - warnings.push(t('parameters.invoke.layer.ipAdapterIncompatibleBaseModel')); + warnings.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'); } } if (!entity.ipAdapter.image) { // No image selected - warnings.push(t('parameters.invoke.layer.ipAdapterNoImageSelected')); + warnings.push('parameters.invoke.layer.ipAdapterNoImageSelected'); } return warnings; }; -export const getControlLayerWarnings = ( - entity: CanvasControlLayerState, - model: ParameterModel | null, - t: TFunction -): string[] => { +export const getControlLayerWarnings = (entity: CanvasControlLayerState, model: ParameterModel | null): string[] => { const warnings: string[] = []; if (entity.objects.length === 0) { // Layer is in empty state - skip other checks - warnings.push(t('parameters.invoke.layer.emptyLayer')); + warnings.push('parameters.invoke.layer.emptyLayer'); } else { if (!entity.controlAdapter.model) { // No model selected - warnings.push(t('parameters.invoke.layer.controlAdapterNoModelSelected')); + warnings.push('parameters.invoke.layer.controlAdapterNoModelSelected'); } else if (model) { if (model.base === 'sd-3' || model.base === 'sd-2') { // Unsupported model architecture - warnings.push(t('controlLayers.invalidBaseModelType')); + warnings.push('controlLayers.invalidBaseModelType'); } else if (entity.controlAdapter.model.base !== model.base) { // Supported model architecture but doesn't match - warnings.push(t('parameters.invoke.layer.controlAdapterIncompatibleBaseModel')); + warnings.push('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'); } } } @@ -117,31 +110,23 @@ export const getControlLayerWarnings = ( return warnings; }; -export const getRasterLayerWarnings = ( - entity: CanvasRasterLayerState, - _model: ParameterModel | null, - t: TFunction -): string[] => { +export const getRasterLayerWarnings = (entity: CanvasRasterLayerState, _model: ParameterModel | null): string[] => { const warnings: string[] = []; if (entity.objects.length === 0) { // Layer is in empty state - skip other checks - warnings.push(t('parameters.invoke.layer.emptyLayer')); + warnings.push('parameters.invoke.layer.emptyLayer'); } return warnings; }; -export const getInpaintMaskWarnings = ( - entity: CanvasInpaintMaskState, - _model: ParameterModel | null, - t: TFunction -): string[] => { +export const getInpaintMaskWarnings = (entity: CanvasInpaintMaskState, _model: ParameterModel | null): string[] => { const warnings: string[] = []; if (entity.objects.length === 0) { // Layer is in empty state - skip other checks - warnings.push(t('parameters.invoke.layer.emptyLayer')); + warnings.push('parameters.invoke.layer.emptyLayer'); } return warnings; diff --git a/invokeai/frontend/web/src/features/queue/store/readiness.ts b/invokeai/frontend/web/src/features/queue/store/readiness.ts index b44e4e6266..0af16bdf78 100644 --- a/invokeai/frontend/web/src/features/queue/store/readiness.ts +++ b/invokeai/frontend/web/src/features/queue/store/readiness.ts @@ -285,10 +285,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY['control_layer']); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems = getControlLayerWarnings(controlLayer, model, i18n.t); + const problems = getControlLayerWarnings(controlLayer, model); if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); @@ -300,10 +300,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems = getGlobalReferenceImageWarnings(entity, model, i18n.t); + const problems = getGlobalReferenceImageWarnings(entity, model); if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); @@ -315,10 +315,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems = getRegionalGuidanceWarnings(entity, model, i18n.t); + const problems = getRegionalGuidanceWarnings(entity, model); if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); @@ -330,10 +330,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems = getRasterLayerWarnings(entity, model, i18n.t); + const problems = getRasterLayerWarnings(entity, model); if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); @@ -345,10 +345,10 @@ const getReasonsWhyCannotEnqueueCanvasTab = (arg: { const layerNumber = i + 1; const layerType = i18n.t(LAYER_TYPE_TO_TKEY[entity.type]); const prefix = `${layerLiteral} #${layerNumber} (${layerType})`; - const problems = getInpaintMaskWarnings(entity, model, i18n.t); + const problems = getInpaintMaskWarnings(entity, model); if (problems.length) { - const content = upperFirst(problems.join(', ')); + const content = upperFirst(problems.map((p) => i18n.t(p)).join(', ')); reasons.push({ prefix, content }); } }); From df0c7d73f36a7927055a692ea18e98743563fe99 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:26:09 +1000 Subject: [PATCH 24/28] feat(ui): use regional guidance validation utils in graph builders --- .../nodes/util/graph/generation/addRegions.ts | 50 +++++++------------ .../util/graph/generation/buildFLUXGraph.ts | 2 +- .../util/graph/generation/buildSD1Graph.ts | 2 +- .../util/graph/generation/buildSDXLGraph.ts | 2 +- 4 files changed, 20 insertions(+), 36 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts index 1c058c5f4c..a1921200d5 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addRegions.ts @@ -3,15 +3,12 @@ import { deepClone } from 'common/util/deepClone'; import { withResultAsync } from 'common/util/result'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; import { getPrefixedId } from 'features/controlLayers/konva/util'; -import type { - CanvasRegionalGuidanceState, - IPAdapterConfig, - Rect, - RegionalGuidanceReferenceImageState, -} from 'features/controlLayers/store/types'; +import type { CanvasRegionalGuidanceState, Rect } from 'features/controlLayers/store/types'; +import { getRegionalGuidanceWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; import { serializeError } from 'serialize-error'; -import type { BaseModelType, Invocation } from 'services/api/types'; +import type { Invocation } from 'services/api/types'; import { assert } from 'tsafe'; const log = logger('system'); @@ -23,19 +20,12 @@ type AddedRegionResult = { addedIPAdapters: number; }; -const isValidRegion = (rg: CanvasRegionalGuidanceState, base: BaseModelType) => { - const isEnabled = rg.isEnabled; - const hasTextPrompt = Boolean(rg.positivePrompt || rg.negativePrompt); - const hasIPAdapter = rg.referenceImages.filter(({ ipAdapter }) => isValidIPAdapter(ipAdapter, base)).length > 0; - return isEnabled && (hasTextPrompt || hasIPAdapter); -}; - type AddRegionsArg = { manager: CanvasManager; regions: CanvasRegionalGuidanceState[]; g: Graph; bbox: Rect; - base: BaseModelType; + model: ParameterModel; posCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'>; negCond: Invocation<'compel' | 'sdxl_compel_prompt' | 'flux_text_encoder'> | null; posCondCollect: Invocation<'collect'>; @@ -49,7 +39,7 @@ type AddRegionsArg = { * @param regions Array of regions to add * @param g The graph to add the layers to * @param bbox The bounding box - * @param base The base model type + * @param model The main model * @param posCond The positive conditioning node * @param negCond The negative conditioning node * @param posCondCollect The positive conditioning collector @@ -63,17 +53,23 @@ export const addRegions = async ({ regions, g, bbox, - base, + model, posCond, negCond, posCondCollect, negCondCollect, ipAdapterCollect, }: AddRegionsArg): Promise => { - const isSDXL = base === 'sdxl'; - const isFLUX = base === 'flux'; + const isSDXL = model.base === 'sdxl'; + const isFLUX = model.base === 'flux'; + + const validRegions = regions.filter((rg) => { + if (!rg.isEnabled) { + return false; + } + return getRegionalGuidanceWarnings(rg, model).length === 0; + }); - const validRegions = regions.filter((rg) => isValidRegion(rg, base)); const results: AddedRegionResult[] = []; for (const region of validRegions) { @@ -275,11 +271,7 @@ export const addRegions = async ({ } } - const validRGIPAdapters: RegionalGuidanceReferenceImageState[] = region.referenceImages.filter(({ ipAdapter }) => - isValidIPAdapter(ipAdapter, base) - ); - - for (const { id, ipAdapter } of validRGIPAdapters) { + for (const { id, ipAdapter } of region.referenceImages) { assert(!isFLUX, 'Regional IP adapters are not supported for FLUX.'); result.addedIPAdapters++; @@ -313,11 +305,3 @@ export const addRegions = async ({ return results; }; - -const isValidIPAdapter = (ipAdapter: IPAdapterConfig, base: BaseModelType): boolean => { - // Must be a model that matches the current base and must have a control image - const hasModel = Boolean(ipAdapter.model); - const modelMatchesBase = ipAdapter.model?.base === base; - const hasImage = Boolean(ipAdapter.image); - return hasModel && modelMatchesBase && hasImage; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 4b3cad0774..f8d310ee7e 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -224,7 +224,7 @@ export const buildFLUXGraph = async ( regions: canvas.regionalGuidance.entities, g, bbox: canvas.bbox.rect, - base: modelConfig.base, + model: modelConfig, posCond, negCond: null, posCondCollect, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index ab38035b4a..58b195ed61 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -270,7 +270,7 @@ export const buildSD1Graph = async ( regions: canvas.regionalGuidance.entities, g, bbox: canvas.bbox.rect, - base: modelConfig.base, + model: modelConfig, posCond, negCond, posCondCollect, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 4d84c025ec..8aff599743 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -275,7 +275,7 @@ export const buildSDXLGraph = async ( regions: canvas.regionalGuidance.entities, g, bbox: canvas.bbox.rect, - base: modelConfig.base, + model: modelConfig, posCond, negCond, posCondCollect, From 46a09d9e908c9297c761d335741b9359614f847f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 13:32:51 +1000 Subject: [PATCH 25/28] feat(ui): format warnings tooltip --- invokeai/frontend/web/public/locales/en.json | 3 ++- .../common/CanvasEntityHeaderWarnings.tsx | 16 ++++++++++------ 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index dfb75eb2a8..c8771a5c44 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -176,7 +176,8 @@ "reset": "Reset", "none": "None", "new": "New", - "generating": "Generating" + "generating": "Generating", + "warnings": "Warnings" }, "hrf": { "hrf": "High Resolution Fix", diff --git a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx index c6ef5ad7db..f13de15c19 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/common/CanvasEntityHeaderWarnings.tsx @@ -1,4 +1,4 @@ -import { IconButton, ListItem, UnorderedList } from '@invoke-ai/ui-library'; +import { Flex, IconButton, ListItem, Text, UnorderedList } from '@invoke-ai/ui-library'; import { createSelector } from '@reduxjs/toolkit'; import { EMPTY_ARRAY } from 'app/store/constants'; import { useAppSelector } from 'app/store/storeHooks'; @@ -86,12 +86,16 @@ export const CanvasEntityHeaderWarnings = memo(() => { CanvasEntityHeaderWarnings.displayName = 'CanvasEntityHeaderWarnings'; const TooltipContent = memo((props: { warnings: string[] }) => { + const { t } = useTranslation(); return ( - - {props.warnings.map((warning, index) => ( - {warning} - ))} - + + {t('common.warnings')}: + + {props.warnings.map((warning, index) => ( + {warning} + ))} + + ); }); TooltipContent.displayName = 'TooltipContent'; From 08704ee8ec888c4a636c1eaa0bbb40e279cbe4e6 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 15:32:48 +1000 Subject: [PATCH 26/28] feat(ui): use canvas layer validators in control/ip adapter graph builders --- .../graph/generation/addControlAdapters.ts | 85 ++++++++++--------- .../util/graph/generation/addIPAdapters.ts | 28 +++--- .../util/graph/generation/buildFLUXGraph.ts | 19 +++-- .../util/graph/generation/buildSD1Graph.ts | 31 ++++--- .../util/graph/generation/buildSDXLGraph.ts | 31 ++++--- 5 files changed, 106 insertions(+), 88 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts index d012ca2853..a41dc06550 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addControlAdapters.ts @@ -1,35 +1,41 @@ import { logger } from 'app/logging/logger'; import { withResultAsync } from 'common/util/result'; import type { CanvasManager } from 'features/controlLayers/konva/CanvasManager'; -import type { - CanvasControlLayerState, - ControlNetConfig, - Rect, - T2IAdapterConfig, -} from 'features/controlLayers/store/types'; +import type { CanvasControlLayerState, Rect } from 'features/controlLayers/store/types'; +import { getControlLayerWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; import { serializeError } from 'serialize-error'; -import type { BaseModelType, ImageDTO, Invocation } from 'services/api/types'; +import type { ImageDTO, Invocation } from 'services/api/types'; import { assert } from 'tsafe'; const log = logger('system'); +type AddControlNetsArg = { + manager: CanvasManager; + entities: CanvasControlLayerState[]; + g: Graph; + rect: Rect; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + type AddControlNetsResult = { addedControlNets: number; }; -export const addControlNets = async ( - manager: CanvasManager, - layers: CanvasControlLayerState[], - g: Graph, - rect: Rect, - collector: Invocation<'collect'>, - base: BaseModelType -): Promise => { - const validControlLayers = layers - .filter((layer) => layer.isEnabled) - .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)) - .filter((layer) => layer.controlAdapter.type === 'controlnet'); +export const addControlNets = async ({ + manager, + entities, + g, + rect, + collector, + model, +}: AddControlNetsArg): Promise => { + const validControlLayers = entities + .filter((entity) => entity.isEnabled) + .filter((entity) => entity.controlAdapter.type === 'controlnet') + .filter((entity) => getControlLayerWarnings(entity, model).length === 0); const result: AddControlNetsResult = { addedControlNets: 0, @@ -54,22 +60,31 @@ export const addControlNets = async ( return result; }; +type AddT2IAdaptersArg = { + manager: CanvasManager; + entities: CanvasControlLayerState[]; + g: Graph; + rect: Rect; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + type AddT2IAdaptersResult = { addedT2IAdapters: number; }; -export const addT2IAdapters = async ( - manager: CanvasManager, - layers: CanvasControlLayerState[], - g: Graph, - rect: Rect, - collector: Invocation<'collect'>, - base: BaseModelType -): Promise => { - const validControlLayers = layers - .filter((layer) => layer.isEnabled) - .filter((layer) => isValidControlAdapter(layer.controlAdapter, base)) - .filter((layer) => layer.controlAdapter.type === 't2i_adapter'); +export const addT2IAdapters = async ({ + manager, + entities, + g, + rect, + collector, + model, +}: AddT2IAdaptersArg): Promise => { + const validControlLayers = entities + .filter((entity) => entity.isEnabled) + .filter((entity) => entity.controlAdapter.type === 't2i_adapter') + .filter((entity) => getControlLayerWarnings(entity, model).length === 0); const result: AddT2IAdaptersResult = { addedT2IAdapters: 0, @@ -145,11 +160,3 @@ const addT2IAdapterToGraph = ( g.addEdge(t2iAdapter, 't2i_adapter', collector, 'item'); }; - -const isValidControlAdapter = (controlAdapter: ControlNetConfig | T2IAdapterConfig, base: BaseModelType): boolean => { - // Must be have a model - const hasModel = Boolean(controlAdapter.model); - // Model must match the current base model - const modelMatchesBase = controlAdapter.model?.base === base; - return hasModel && modelMatchesBase; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts index 81b98f3ef5..0a3a43a018 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addIPAdapters.ts @@ -1,19 +1,23 @@ import type { CanvasReferenceImageState } from 'features/controlLayers/store/types'; +import { getGlobalReferenceImageWarnings } from 'features/controlLayers/store/validators'; import type { Graph } from 'features/nodes/util/graph/generation/Graph'; -import type { BaseModelType, Invocation } from 'services/api/types'; +import type { ParameterModel } from 'features/parameters/types/parameterSchemas'; +import type { Invocation } from 'services/api/types'; import { assert } from 'tsafe'; type AddIPAdaptersResult = { addedIPAdapters: number; }; -export const addIPAdapters = ( - ipAdapters: CanvasReferenceImageState[], - g: Graph, - collector: Invocation<'collect'>, - base: BaseModelType -): AddIPAdaptersResult => { - const validIPAdapters = ipAdapters.filter((entity) => isValidIPAdapter(entity, base)); +type AddIPAdaptersArg = { + entities: CanvasReferenceImageState[]; + g: Graph; + collector: Invocation<'collect'>; + model: ParameterModel; +}; + +export const addIPAdapters = ({ entities, g, collector, model }: AddIPAdaptersArg): AddIPAdaptersResult => { + const validIPAdapters = entities.filter((entity) => getGlobalReferenceImageWarnings(entity, model).length === 0); const result: AddIPAdaptersResult = { addedIPAdapters: 0, @@ -76,11 +80,3 @@ const addIPAdapter = (entity: CanvasReferenceImageState, g: Graph, collector: In g.addEdge(ipAdapterNode, 'ip_adapter', collector, 'item'); }; - -const isValidIPAdapter = ({ isEnabled, ipAdapter }: CanvasReferenceImageState, base: BaseModelType): boolean => { - // Must be have a model that matches the current base and must have a control image - const hasModel = Boolean(ipAdapter.model); - const modelMatchesBase = ipAdapter.model?.base === base; - const hasImage = Boolean(ipAdapter.image); - return isEnabled && hasModel && modelMatchesBase && hasImage; -}; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index f8d310ee7e..885954e995 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -199,14 +199,14 @@ export const buildFLUXGraph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { @@ -217,7 +217,12 @@ export const buildFLUXGraph = async ( type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, + g, + collector: ipAdapterCollect, + model: modelConfig, + }); const regionsResult = await addRegions({ manager, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts index 58b195ed61..7522227007 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSD1Graph.ts @@ -227,14 +227,14 @@ export const buildSD1Graph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { @@ -245,14 +245,14 @@ export const buildSD1Graph = async ( type: 'collect', id: getPrefixedId('t2i_adapter_collector'), }); - const t2iAdapterResult = await addT2IAdapters( + const t2iAdapterResult = await addT2IAdapters({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - t2iAdapterCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: t2iAdapterCollector, + model: modelConfig, + }); if (t2iAdapterResult.addedT2IAdapters > 0) { g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter'); } else { @@ -263,7 +263,12 @@ export const buildSD1Graph = async ( type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, + g, + collector: ipAdapterCollect, + model: modelConfig, + }); const regionsResult = await addRegions({ manager, diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts index 8aff599743..9357a291b4 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildSDXLGraph.ts @@ -232,14 +232,14 @@ export const buildSDXLGraph = async ( type: 'collect', id: getPrefixedId('control_net_collector'), }); - const controlNetResult = await addControlNets( + const controlNetResult = await addControlNets({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - controlNetCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: controlNetCollector, + model: modelConfig, + }); if (controlNetResult.addedControlNets > 0) { g.addEdge(controlNetCollector, 'collection', denoise, 'control'); } else { @@ -250,14 +250,14 @@ export const buildSDXLGraph = async ( type: 'collect', id: getPrefixedId('t2i_adapter_collector'), }); - const t2iAdapterResult = await addT2IAdapters( + const t2iAdapterResult = await addT2IAdapters({ manager, - canvas.controlLayers.entities, + entities: canvas.controlLayers.entities, g, - canvas.bbox.rect, - t2iAdapterCollector, - modelConfig.base - ); + rect: canvas.bbox.rect, + collector: t2iAdapterCollector, + model: modelConfig, + }); if (t2iAdapterResult.addedT2IAdapters > 0) { g.addEdge(t2iAdapterCollector, 'collection', denoise, 't2i_adapter'); } else { @@ -268,7 +268,12 @@ export const buildSDXLGraph = async ( type: 'collect', id: getPrefixedId('ip_adapter_collector'), }); - const ipAdapterResult = addIPAdapters(canvas.referenceImages.entities, g, ipAdapterCollect, modelConfig.base); + const ipAdapterResult = addIPAdapters({ + entities: canvas.referenceImages.entities, + g, + collector: ipAdapterCollect, + model: modelConfig, + }); const regionsResult = await addRegions({ manager, From 4d7667f63d68235a6a02025bb3e66cd166b71642 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 15:43:49 +1000 Subject: [PATCH 27/28] fix(ui): add missing translations --- invokeai/frontend/web/public/locales/en.json | 4 ++++ .../web/src/features/controlLayers/store/validators.ts | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index 916cea07c1..3b04bcff71 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -1041,6 +1041,7 @@ "noNodesInGraph": "No nodes in graph", "systemDisconnected": "System disconnected", "layer": { + "unsupportedModel": "layer not supported for selected base model", "controlAdapterNoModelSelected": "no Control Adapter model selected", "controlAdapterIncompatibleBaseModel": "incompatible Control Adapter base model", "t2iAdapterIncompatibleBboxWidth": "$t(parameters.invoke.layer.t2iAdapterRequiresDimensionsToBeMultipleOf) {{multiple}}, bbox width is {{width}}", @@ -1051,6 +1052,9 @@ "ipAdapterIncompatibleBaseModel": "incompatible IP Adapter base model", "ipAdapterNoImageSelected": "no IP Adapter image selected", "rgNoPromptsOrIPAdapters": "no text prompts or IP Adapters", + "rgNegativePromptNotSupported": "negative prompt not supported for selected base model", + "rgReferenceImagesNotSupported": "regional reference images not supported for selected base model", + "rgAutoNegativeNotSupported": "auto-negative not supported for selected base model", "emptyLayer": "empty layer" } }, diff --git a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts index 5b63572140..0b8a34c3bd 100644 --- a/invokeai/frontend/web/src/features/controlLayers/store/validators.ts +++ b/invokeai/frontend/web/src/features/controlLayers/store/validators.ts @@ -25,7 +25,7 @@ export const getRegionalGuidanceWarnings = ( if (model) { if (model.base === 'sd-3' || model.base === 'sd-2') { // Unsupported model architecture - warnings.push('controlLayers.invalidBaseModelType'); + warnings.push('parameters.invoke.layer.unsupportedModel'); } else if (model.base === 'flux') { // Some features are not supported for flux models if (entity.negativePrompt !== null) { @@ -71,7 +71,7 @@ export const getGlobalReferenceImageWarnings = ( } else if (model) { if (model.base === 'sd-3' || model.base === 'sd-2') { // Unsupported model architecture - warnings.push('controlLayers.invalidBaseModelType'); + warnings.push('parameters.invoke.layer.unsupportedModel'); } else if (entity.ipAdapter.model.base !== model.base) { // Supported model architecture but doesn't match warnings.push('parameters.invoke.layer.ipAdapterIncompatibleBaseModel'); @@ -99,7 +99,7 @@ export const getControlLayerWarnings = (entity: CanvasControlLayerState, model: } else if (model) { if (model.base === 'sd-3' || model.base === 'sd-2') { // Unsupported model architecture - warnings.push('controlLayers.invalidBaseModelType'); + warnings.push('parameters.invoke.layer.unsupportedModel'); } else if (entity.controlAdapter.model.base !== model.base) { // Supported model architecture but doesn't match warnings.push('parameters.invoke.layer.controlAdapterIncompatibleBaseModel'); From 7d488a53528450a47f248a46caa1fb0547a2a5ec Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 29 Nov 2024 15:51:24 +1000 Subject: [PATCH 28/28] feat(ui): add delete button to regional ref image empty state --- .../RegionalGuidanceDeletePromptButton.tsx | 33 ++++++++++--------- ...nalGuidanceIPAdapterSettingsEmptyState.tsx | 6 ++++ 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx index 2fc7483756..3e11bdc4e2 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton.tsx @@ -1,27 +1,28 @@ -import { IconButton, Tooltip } from '@invoke-ai/ui-library'; +import type { IconButtonProps } from '@invoke-ai/ui-library'; +import { IconButton } from '@invoke-ai/ui-library'; import { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { PiTrashSimpleFill } from 'react-icons/pi'; +import { PiXBold } from 'react-icons/pi'; -type Props = { +type Props = Omit & { onDelete: () => void; }; -export const RegionalGuidanceDeletePromptButton = memo(({ onDelete }: Props) => { +export const RegionalGuidanceDeletePromptButton = memo(({ onDelete, ...rest }: Props) => { const { t } = useTranslation(); return ( - - } - onClick={onDelete} - flexGrow={0} - size="sm" - p={0} - colorScheme="error" - /> - + } + onClick={onDelete} + flexGrow={0} + size="sm" + p={0} + colorScheme="error" + {...rest} + /> ); }); diff --git a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx index c98d6bc23d..722720dfbf 100644 --- a/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx +++ b/invokeai/frontend/web/src/features/controlLayers/components/RegionalGuidance/RegionalGuidanceIPAdapterSettingsEmptyState.tsx @@ -1,8 +1,10 @@ import { Button, Flex, Text } from '@invoke-ai/ui-library'; import { useAppDispatch } from 'app/store/storeHooks'; import { useImageUploadButton } from 'common/hooks/useImageUploadButton'; +import { RegionalGuidanceDeletePromptButton } from 'features/controlLayers/components/RegionalGuidance/RegionalGuidanceDeletePromptButton'; import { useEntityIdentifierContext } from 'features/controlLayers/contexts/EntityIdentifierContext'; import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy'; +import { rgIPAdapterDeleted } from 'features/controlLayers/store/canvasSlice'; import type { SetRegionalGuidanceReferenceImageDndTargetData } from 'features/dnd/dnd'; import { setRegionalGuidanceReferenceImageDndTarget } from 'features/dnd/dnd'; import { DndDropTarget } from 'features/dnd/DndDropTarget'; @@ -31,6 +33,9 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag const onClickGalleryButton = useCallback(() => { dispatch(activeTabCanvasRightPanelChanged('gallery')); }, [dispatch]); + const onDeleteIPAdapter = useCallback(() => { + dispatch(rgIPAdapterDeleted({ entityIdentifier, referenceImageId })); + }, [dispatch, entityIdentifier, referenceImageId]); const dndTargetData = useMemo( () => @@ -43,6 +48,7 @@ export const RegionalGuidanceIPAdapterSettingsEmptyState = memo(({ referenceImag return ( +