From fda7aaa7ca74925736bdfe339d08f3ce7ebf29a6 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 20 Nov 2024 19:48:04 +0000
Subject: [PATCH] Pass RegionalPromptingExtension down to the
 CustomDoubleStreamBlockProcessor in FLUX.

---
 invokeai/app/invocations/flux_denoise.py    | 160 +++++-------------
 .../backend/flux/custom_block_processor.py  |   2 +
 invokeai/backend/flux/denoise.py            |  29 ++--
 .../regional_prompting_extension.py         |  96 +++++++++++
 invokeai/backend/flux/model.py              |   3 +
 5 files changed, 159 insertions(+), 131 deletions(-)
 create mode 100644 invokeai/backend/flux/extensions/regional_prompting_extension.py

diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 9faa1207d2..67bcfc785f 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -4,7 +4,6 @@ from typing import Callable, Iterator, Optional, Tuple
 import numpy as np
 import numpy.typing as npt
 import torch
-import torchvision
 import torchvision.transforms as tv_transforms
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -31,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
 from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
@@ -43,15 +43,14 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
-from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
+from invokeai.backend.flux.text_conditioning import FluxTextConditioning
 from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo, Range
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.mask import to_standard_float_mask
 
 
 @invocation(
@@ -142,113 +141,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         name = context.tensors.save(tensor=latents)
         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
 
-    @staticmethod
-    def _preprocess_regional_prompt_mask(
-        mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
-    ) -> torch.Tensor:
-        """Preprocess a regional prompt mask to match the target height and width.
-        If mask is None, returns a mask of all ones with the target height and width.
-        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
-
-        Returns:
-            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
-        """
-
-        if mask is None:
-            return torch.ones((1, 1, target_height, target_width), dtype=dtype)
-
-        mask = to_standard_float_mask(mask, out_dtype=dtype)
-
-        tf = torchvision.transforms.Resize(
-            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
-        )
-
-        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
-        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
-        resized_mask = tf(mask)
-        return resized_mask
-
-    def _load_text_conditioning(
-        self,
-        context: InvocationContext,
-        cond_field: FluxConditioningField | list[FluxConditioningField],
-        latent_height: int,
-        latent_width: int,
-        dtype: torch.dtype,
-    ) -> list[FluxTextConditioning]:
-        """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
-        # Normalize to a list of FluxConditioningFields.
-        cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
-
-        text_conditionings: list[FluxTextConditioning] = []
-        for cond_field in cond_list:
-            # Load the text embeddings.
-            cond_data = context.conditioning.load(cond_field.conditioning_name)
-            assert len(cond_data.conditionings) == 1
-            flux_conditioning = cond_data.conditionings[0]
-            assert isinstance(flux_conditioning, FLUXConditioningInfo)
-            flux_conditioning = flux_conditioning.to(dtype=dtype)
-            t5_embeddings = flux_conditioning.t5_embeds
-            clip_embeddings = flux_conditioning.clip_embeds
-
-            # Load the mask, if provided.
-            mask: Optional[torch.Tensor] = None
-            if cond_field.mask is not None:
-                mask = context.tensors.load(cond_field.mask.tensor_name)
-            mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype)
-
-            text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
-
-        return text_conditionings
-
-    def _concat_regional_text_conditioning(
-        self, text_conditionings: list[FluxTextConditioning]
-    ) -> FluxRegionalTextConditioning:
-        """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
-        concat_t5_embeddings: list[torch.Tensor] = []
-        concat_clip_embeddings: list[torch.Tensor] = []
-        concat_image_masks: list[torch.Tensor] = []
-        concat_t5_embedding_ranges: list[Range] = []
-        concat_clip_embedding_ranges: list[Range] = []
-
-        cur_t5_embedding_len = 0
-        cur_clip_embedding_len = 0
-        for text_conditioning in text_conditionings:
-            concat_t5_embeddings.append(text_conditioning.t5_embeddings)
-            concat_clip_embeddings.append(text_conditioning.clip_embeddings)
-
-            concat_t5_embedding_ranges.append(
-                Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
-            )
-            concat_clip_embedding_ranges.append(
-                Range(
-                    start=cur_clip_embedding_len,
-                    end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1],
-                )
-            )
-
-            concat_image_masks.append(text_conditioning.mask)
-
-            cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
-            cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1]
-
-        t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
-
-        # Initialize the txt_ids tensor.
-        pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
-        t5_txt_ids = torch.zeros(
-            pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
-        )
-
-        return FluxRegionalTextConditioning(
-            t5_embeddings=t5_embeddings,
-            clip_embeddings=torch.cat(concat_clip_embeddings, dim=1),
-            t5_txt_ids=t5_txt_ids,
-            image_masks=torch.cat(concat_image_masks, dim=1),
-            t5_embedding_ranges=concat_t5_embedding_ranges,
-            clip_embedding_ranges=concat_clip_embedding_ranges,
-        )
-
     def _run_diffusion(
         self,
         context: InvocationContext,
@@ -288,10 +180,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             latent_width=latent_w,
             dtype=inference_dtype,
         )
-
-        pos_regional_text_conditioning = self._concat_regional_text_conditioning(pos_text_conditionings)
-        neg_regional_text_conditioning = (
-            self._concat_regional_text_conditioning(neg_text_conditionings) if neg_text_conditionings else None
+        pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(pos_text_conditionings)
+        neg_regional_prompting_extension = (
+            RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings)
+            if neg_text_conditionings
+            else None
         )
 
         transformer_info = context.models.load(self.transformer.transformer)
@@ -436,8 +329,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             model=transformer,
             img=x,
             img_ids=img_ids,
-            pos_text_conditioning=pos_regional_text_conditioning,
-            neg_text_conditioning=neg_regional_text_conditioning,
+            pos_regional_prompting_extension=pos_regional_prompting_extension,
+            neg_regional_prompting_extension=neg_regional_prompting_extension,
             timesteps=timesteps,
             step_callback=self._build_step_callback(context),
             guidance=self.guidance,
@@ -451,6 +344,39 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         x = unpack(x.float(), self.height, self.width)
         return x
 
+    def _load_text_conditioning(
+        self,
+        context: InvocationContext,
+        cond_field: FluxConditioningField | list[FluxConditioningField],
+        latent_height: int,
+        latent_width: int,
+        dtype: torch.dtype,
+    ) -> list[FluxTextConditioning]:
+        """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
+        # Normalize to a list of FluxConditioningFields.
+        cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
+
+        text_conditionings: list[FluxTextConditioning] = []
+        for cond_field in cond_list:
+            # Load the text embeddings.
+            cond_data = context.conditioning.load(cond_field.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            flux_conditioning = cond_data.conditionings[0]
+            assert isinstance(flux_conditioning, FLUXConditioningInfo)
+            flux_conditioning = flux_conditioning.to(dtype=dtype)
+            t5_embeddings = flux_conditioning.t5_embeds
+            clip_embeddings = flux_conditioning.clip_embeds
+
+            # Load the mask, if provided.
+            mask: Optional[torch.Tensor] = None
+            if cond_field.mask is not None:
+                mask = context.tensors.load(cond_field.mask.tensor_name)
+            mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype)
+
+            text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
+
+        return text_conditionings
+
     @classmethod
     def prep_cfg_scale(
         cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int

diff --git a/invokeai/backend/flux/custom_block_processor.py b/invokeai/backend/flux/custom_block_processor.py
index e0c7779e93..ae339cbd0e 100644
--- a/invokeai/backend/flux/custom_block_processor.py
+++ b/invokeai/backend/flux/custom_block_processor.py
@@ -1,6 +1,7 @@
 import einops
 import torch
 
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.math import attention
 from invokeai.backend.flux.modules.layers import DoubleStreamBlock
@@ -63,6 +64,7 @@ class CustomDoubleStreamBlockProcessor:
         vec: torch.Tensor,
         pe: torch.Tensor,
         ip_adapter_extensions: list[XLabsIPAdapterExtension],
+        regional_prompting_extension: RegionalPromptingExtension,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """A custom implementation of DoubleStreamBlock.forward() with additional features:
         - IP-Adapter support

diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index c1cb3bbb71..ae541ba7d8 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -7,10 +7,10 @@ from tqdm import tqdm
 from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
 from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 
 
@@ -19,8 +19,8 @@ def denoise(
     # model input
     img: torch.Tensor,
     img_ids: torch.Tensor,
-    pos_text_conditioning: FluxRegionalTextConditioning,
-    neg_text_conditioning: FluxRegionalTextConditioning | None,
+    pos_regional_prompting_extension: RegionalPromptingExtension,
+    neg_regional_prompting_extension: RegionalPromptingExtension | None,
     # sampling parameters
     timesteps: list[float],
     step_callback: Callable[[PipelineIntermediateState], None],
@@ -50,16 +50,16 @@ def denoise(
         # Run ControlNet models.
         controlnet_residuals: list[ControlNetFluxOutput] = []
         for controlnet_extension in controlnet_extensions:
-            # FIX(ryand): Revive ControlNet functionality.
+            # TODO(ryand): Think about how to handle regional prompting with ControlNet.
             controlnet_residuals.append(
                 controlnet_extension.run_controlnet(
                     timestep_index=step_index,
                     total_num_timesteps=total_steps,
                     img=img,
                     img_ids=img_ids,
-                    txt=txt,
-                    txt_ids=txt_ids,
-                    y=vec,
+                    txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                    txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                    y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                     timesteps=t_vec,
                     guidance=guidance_vec,
                 )
@@ -74,9 +74,9 @@ def denoise(
         pred = model(
             img=img,
             img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
+            txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+            txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+            y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
             timesteps=t_vec,
             guidance=guidance_vec,
             timestep_index=step_index,
@@ -84,6 +84,7 @@ def denoise(
             controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
             controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
             ip_adapter_extensions=pos_ip_adapter_extensions,
+            regional_prompting_extension=pos_regional_prompting_extension,
         )
 
         step_cfg_scale = cfg_scale[step_index]
@@ -93,15 +94,15 @@ def denoise(
             # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
             # on systems with sufficient VRAM.
-            if neg_txt is None or neg_txt_ids is None or neg_vec is None:
+            if neg_regional_prompting_extension is None:
                 raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
 
             neg_pred = model(
                 img=img,
                 img_ids=img_ids,
-                txt=neg_txt,
-                txt_ids=neg_txt_ids,
-                y=neg_vec,
+                txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                 timesteps=t_vec,
                 guidance=guidance_vec,
                 timestep_index=step_index,

diff --git a/invokeai/backend/flux/extensions/regional_prompting_extension.py b/invokeai/backend/flux/extensions/regional_prompting_extension.py
new file mode 100644
index 0000000000..21b2279b14
--- /dev/null
+++ b/invokeai/backend/flux/extensions/regional_prompting_extension.py
@@ -0,0 +1,96 @@
+from typing import Optional
+
+import torch
+import torchvision
+
+from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
+
+
+class RegionalPromptingExtension:
+    """A class for managing regional prompting with FLUX."""
+
+    def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning):
+        self.regional_text_conditioning = regional_text_conditioning
+
+    @classmethod
+    def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]):
+        return cls(regional_text_conditioning=cls._concat_regional_text_conditioning(text_conditioning))
+
+    @classmethod
+    def _concat_regional_text_conditioning(
+        cls,
+        text_conditionings: list[FluxTextConditioning],
+    ) -> FluxRegionalTextConditioning:
+        """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
+        concat_t5_embeddings: list[torch.Tensor] = []
+        concat_clip_embeddings: list[torch.Tensor] = []
+        concat_image_masks: list[torch.Tensor] = []
+        concat_t5_embedding_ranges: list[Range] = []
+        concat_clip_embedding_ranges: list[Range] = []
+
+        cur_t5_embedding_len = 0
+        cur_clip_embedding_len = 0
+        for text_conditioning in text_conditionings:
+            concat_t5_embeddings.append(text_conditioning.t5_embeddings)
+            concat_clip_embeddings.append(text_conditioning.clip_embeddings)
+
+            concat_t5_embedding_ranges.append(
+                Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
+            )
+            concat_clip_embedding_ranges.append(
+                Range(
+                    start=cur_clip_embedding_len,
+                    end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1],
+                )
+            )
+
+            concat_image_masks.append(text_conditioning.mask)
+
+            cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
+            cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1]
+
+        t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
+
+        # Initialize the txt_ids tensor.
+        pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
+        t5_txt_ids = torch.zeros(
+            pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
+        )
+
+        return FluxRegionalTextConditioning(
+            t5_embeddings=t5_embeddings,
+            clip_embeddings=torch.cat(concat_clip_embeddings, dim=1),
+            t5_txt_ids=t5_txt_ids,
+            image_masks=torch.cat(concat_image_masks, dim=1),
+            t5_embedding_ranges=concat_t5_embedding_ranges,
+            clip_embedding_ranges=concat_clip_embedding_ranges,
+        )
+
+    @staticmethod
+    def preprocess_regional_prompt_mask(
+        mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
+        """
+
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=dtype)
+
+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+
+        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
+        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
+        resized_mask = tf(mask)
+        return resized_mask

diff --git a/invokeai/backend/flux/model.py b/invokeai/backend/flux/model.py
index 0dadacd8fe..a0c44f0039 100644
--- a/invokeai/backend/flux/model.py
+++ b/invokeai/backend/flux/model.py
@@ -6,6 +6,7 @@ import torch
 from torch import Tensor, nn
 
 from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.modules.layers import (
     DoubleStreamBlock,
@@ -95,6 +96,7 @@ class Flux(nn.Module):
         controlnet_double_block_residuals: list[Tensor] | None,
         controlnet_single_block_residuals: list[Tensor] | None,
         ip_adapter_extensions: list[XLabsIPAdapterExtension],
+        regional_prompting_extension: RegionalPromptingExtension,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -128,6 +130,7 @@ class Flux(nn.Module):
             vec=vec,
             pe=pe,
             ip_adapter_extensions=ip_adapter_extensions,
+            regional_prompting_extension=regional_prompting_extension,
         )
 
         if controlnet_double_block_residuals is not None:
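
For orientation, here is a minimal usage sketch (not part of the patch) of how the new extension is built from per-region conditioning and what the concatenated result looks like. The FluxTextConditioning fields follow the positional order used in the hunks above (t5_embeddings, clip_embeddings, mask); the concrete tensor sizes are illustrative assumptions.

import torch

from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.text_conditioning import FluxTextConditioning

latent_h, latent_w = 64, 64

# Region A: a prompt covering the whole image (mask of all ones).
region_a = FluxTextConditioning(
    torch.randn(1, 128, 4096),             # T5 embeddings: (batch, seq_len, dim), sizes assumed
    torch.randn(1, 1, 768),                # CLIP embeddings, shape assumed
    torch.ones(1, 1, latent_h, latent_w),  # mask: (1, 1, latent_h, latent_w)
)

# Region B: a prompt restricted to the left half of the image.
mask_b = torch.zeros(1, 1, latent_h, latent_w)
mask_b[:, :, :, : latent_w // 2] = 1.0
region_b = FluxTextConditioning(torch.randn(1, 77, 4096), torch.randn(1, 1, 768), mask_b)

ext = RegionalPromptingExtension.from_text_conditioning([region_a, region_b])
cond = ext.regional_text_conditioning

# T5 embeddings are concatenated along the sequence dim, and each region's
# slice is recorded as a Range: [0, 128) for region A, [128, 205) for region B.
assert cond.t5_embeddings.shape == (1, 205, 4096)
assert (cond.t5_embedding_ranges[1].start, cond.t5_embedding_ranges[1].end) == (128, 205)
# image_masks stacks the per-region masks along dim=1: (1, 2, latent_h, latent_w).
assert cond.image_masks.shape == (1, 2, latent_h, latent_w)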
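
The mask preprocessing helper is worth a concrete illustration. A hedged sketch of its two documented cases follows; the (1, h, w) input shape is an assumption about what to_standard_float_mask accepts, based on the shape comment in the code above.

import torch

from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension

# Case 1: no mask provided -> an all-ones mask over the full latent grid.
full = RegionalPromptingExtension.preprocess_regional_prompt_mask(None, 64, 64, torch.float32)
assert full.shape == (1, 1, 64, 64) and bool(full.all())

# Case 2: a provided mask is resized with 'nearest' interpolation, which
# preserves hard 0/1 region boundaries instead of blending them.
coarse = torch.zeros(1, 8, 8)  # (1, h, w) float mask
coarse[:, :4, :] = 1.0         # select the top half
resized = RegionalPromptingExtension.preprocess_regional_prompt_mask(coarse, 64, 64, torch.float32)
assert resized.shape == (1, 1, 64, 64)
assert set(resized.unique().tolist()) == {0.0, 1.0}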
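
The denoise() changes also tighten the classifier-free-guidance guard: the negative extension is now required whenever the per-step cfg_scale is not 1.0. The combination step itself lies outside this diff; the following is a sketch of the conventional CFG form, not necessarily the exact line used in the file.

import torch

def apply_cfg(pos_pred: torch.Tensor, neg_pred: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Steer away from the negative prediction and toward the positive one.
    return neg_pred + cfg_scale * (pos_pred - neg_pred)

# At cfg_scale == 1.0 this reduces to the positive prediction alone, which is
# why negative conditioning is only demanded when the scale differs from 1.0.
pos, neg = torch.randn(4), torch.randn(4)
assert torch.allclose(apply_cfg(pos, neg, 1.0), pos)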
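
Finally, note that this patch only threads the extension down to CustomDoubleStreamBlockProcessor; nothing consumes it inside the block yet. A speculative sketch of the kind of consumer the plumbing enables: building an attention bias that restricts each image token to the text ranges whose region mask covers it. All names and shapes here are assumptions (in particular, the mask is assumed to have been resized to the packed token grid beforehand), not code from this commit.

import torch

from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension

def build_img_to_txt_attn_bias(ext: RegionalPromptingExtension, packed_h: int, packed_w: int) -> torch.Tensor:
    cond = ext.regional_text_conditioning
    num_img_tokens = packed_h * packed_w
    num_txt_tokens = cond.t5_embeddings.shape[1]
    # Start fully masked; open up only the (image token, text token) pairs
    # that belong to the same region.
    bias = torch.full((num_img_tokens, num_txt_tokens), float("-inf"))
    for region_idx, rng in enumerate(cond.t5_embedding_ranges):
        # image_masks: (1, num_regions, h, w) -> one bool per packed image token.
        in_region = cond.image_masks[0, region_idx].flatten() > 0.5  # (h*w,)
        bias[in_region, rng.start : rng.end] = 0.0
    return bias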