From 85c616fa34e21144dcb2d700f24b5f36a84e970f Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 20 Nov 2024 18:51:43 +0000
Subject: [PATCH] WIP - Pass prompt masks to FLUX model during denoising.

---
 invokeai/app/invocations/fields.py            |   5 +
 invokeai/app/invocations/flux_denoise.py      | 180 +++++++++++++-----
 invokeai/app/invocations/flux_text_encoder.py |  11 +-
 invokeai/backend/flux/denoise.py              |  12 +-
 invokeai/backend/flux/text_conditioning.py    |  32 ++++
 5 files changed, 186 insertions(+), 54 deletions(-)
 create mode 100644 invokeai/backend/flux/text_conditioning.py

diff --git a/invokeai/app/invocations/fields.py b/invokeai/app/invocations/fields.py
index 5e76931933..16bfcaeed7 100644
--- a/invokeai/app/invocations/fields.py
+++ b/invokeai/app/invocations/fields.py
@@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel):
     """A conditioning tensor primitive value"""
 
     conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
+        "included regions should be set to True.",
+    )
 
 
 class SD3ConditioningField(BaseModel):
diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py
index 9e197626b5..9faa1207d2 100644
--- a/invokeai/app/invocations/flux_denoise.py
+++ b/invokeai/app/invocations/flux_denoise.py
@@ -4,6 +4,7 @@ from typing import Callable, Iterator, Optional, Tuple
 import numpy as np
 import numpy.typing as npt
 import torch
+import torchvision
 import torchvision.transforms as tv_transforms
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -42,13 +43,15 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
 from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo, Range
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
 
 
 @invocation(
@@ -87,10 +90,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
         title="Transformer",
     )
-    positive_text_conditioning: FluxConditioningField = InputField(
+    positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
-    negative_text_conditioning: FluxConditioningField | None = InputField(
+    negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
         default=None,
         description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
         input=Input.Connection,
@@ -139,18 +142,112 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         name = context.tensors.save(tensor=latents)
         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
 
+    @staticmethod
+    def _preprocess_regional_prompt_mask(
+        mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
+        """
+
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=dtype)
+
+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+
+        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
+        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
+        resized_mask = tf(mask)
+        return resized_mask
+
     def _load_text_conditioning(
-        self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-        flux_conditioning = flux_conditioning.to(dtype=dtype)
-        t5_embeddings = flux_conditioning.t5_embeds
-        clip_embeddings = flux_conditioning.clip_embeds
-        return t5_embeddings, clip_embeddings
+        self,
+        context: InvocationContext,
+        cond_field: FluxConditioningField | list[FluxConditioningField],
+        latent_height: int,
+        latent_width: int,
+        dtype: torch.dtype,
+    ) -> list[FluxTextConditioning]:
+        """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
+        # Normalize to a list of FluxConditioningFields.
+        cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
+
+        text_conditionings: list[FluxTextConditioning] = []
+        for cond_field in cond_list:
+            # Load the text embeddings.
+            cond_data = context.conditioning.load(cond_field.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            flux_conditioning = cond_data.conditionings[0]
+            assert isinstance(flux_conditioning, FLUXConditioningInfo)
+            flux_conditioning = flux_conditioning.to(dtype=dtype)
+            t5_embeddings = flux_conditioning.t5_embeds
+            clip_embeddings = flux_conditioning.clip_embeds
+
+            # Load the mask, if provided.
+            mask: Optional[torch.Tensor] = None
+            if cond_field.mask is not None:
+                mask = context.tensors.load(cond_field.mask.tensor_name)
+            mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype)
+
+            text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
+
+        return text_conditionings
+
+    def _concat_regional_text_conditioning(
+        self, text_conditionings: list[FluxTextConditioning]
+    ) -> FluxRegionalTextConditioning:
+        """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
+        concat_t5_embeddings: list[torch.Tensor] = []
+        concat_clip_embeddings: list[torch.Tensor] = []
+        concat_image_masks: list[torch.Tensor] = []
+        concat_t5_embedding_ranges: list[Range] = []
+        concat_clip_embedding_ranges: list[Range] = []
+
+        cur_t5_embedding_len = 0
+        cur_clip_embedding_len = 0
+        for text_conditioning in text_conditionings:
+            concat_t5_embeddings.append(text_conditioning.t5_embeddings)
+            concat_clip_embeddings.append(text_conditioning.clip_embeddings)
+
+            concat_t5_embedding_ranges.append(
+                Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
+            )
+            concat_clip_embedding_ranges.append(
+                Range(
+                    start=cur_clip_embedding_len,
+                    end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1],
+                )
+            )
+
+            concat_image_masks.append(text_conditioning.mask)
+
+            cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
+            cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1]
+
+        t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
+
+        # Initialize the txt_ids tensor.
+        pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
+        t5_txt_ids = torch.zeros(
+            pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
+        )
+
+        return FluxRegionalTextConditioning(
+            t5_embeddings=t5_embeddings,
+            clip_embeddings=torch.cat(concat_clip_embeddings, dim=1),
+            t5_txt_ids=t5_txt_ids,
+            image_masks=torch.cat(concat_image_masks, dim=1),
+            t5_embedding_ranges=concat_t5_embedding_ranges,
+            clip_embedding_ranges=concat_clip_embedding_ranges,
+        )
 
     def _run_diffusion(
         self,
@@ -158,17 +255,6 @@
     ):
         inference_dtype = torch.bfloat16
 
-        # Load the conditioning data.
-        pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
-            context, self.positive_text_conditioning.conditioning_name, inference_dtype
-        )
-        neg_t5_embeddings: torch.Tensor | None = None
-        neg_clip_embeddings: torch.Tensor | None = None
-        if self.negative_text_conditioning is not None:
-            neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
-                context, self.negative_text_conditioning.conditioning_name, inference_dtype
-            )
-
         # Load the input latents, if provided.
         init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
         if init_latents is not None:
@@ -183,6 +269,30 @@
             dtype=inference_dtype,
             seed=self.seed,
         )
+        b, _c, latent_h, latent_w = noise.shape
+
+        # Load the conditioning data.
+        pos_text_conditionings = self._load_text_conditioning(
+            context=context,
+            cond_field=self.positive_text_conditioning,
+            latent_height=latent_h,
+            latent_width=latent_w,
+            dtype=inference_dtype,
+        )
+        neg_text_conditionings: list[FluxTextConditioning] | None = None
+        if self.negative_text_conditioning is not None:
+            neg_text_conditionings = self._load_text_conditioning(
+                context=context,
+                cond_field=self.negative_text_conditioning,
+                latent_height=latent_h,
+                latent_width=latent_w,
+                dtype=inference_dtype,
+            )
+
+        pos_regional_text_conditioning = self._concat_regional_text_conditioning(pos_text_conditionings)
+        neg_regional_text_conditioning = (
+            self._concat_regional_text_conditioning(neg_text_conditionings) if neg_text_conditionings else None
+        )
 
         transformer_info = context.models.load(self.transformer.transformer)
         is_schnell = "schnell" in transformer_info.config.config_path
@@ -228,20 +338,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
 
         inpaint_mask = self._prep_inpaint_mask(context, x)
 
-        b, _c, latent_h, latent_w = x.shape
         img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
 
-        pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
-        pos_txt_ids = torch.zeros(
-            pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
-        )
-        neg_txt_ids: torch.Tensor | None = None
-        if neg_t5_embeddings is not None:
-            neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
-            neg_txt_ids = torch.zeros(
-                neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
-            )
-
         # Pack all latent tensors.
         init_latents = pack(init_latents) if init_latents is not None else None
         inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
@@ -338,12 +436,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 model=transformer,
                 img=x,
                 img_ids=img_ids,
-                txt=pos_t5_embeddings,
-                txt_ids=pos_txt_ids,
-                vec=pos_clip_embeddings,
-                neg_txt=neg_t5_embeddings,
-                neg_txt_ids=neg_txt_ids,
-                neg_vec=neg_clip_embeddings,
+                pos_text_conditioning=pos_regional_text_conditioning,
+                neg_text_conditioning=neg_regional_text_conditioning,
                 timesteps=timesteps,
                 step_callback=self._build_step_callback(context),
                 guidance=self.guidance,
diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py
index af250f0f3b..cc9c68eca4 100644
--- a/invokeai/app/invocations/flux_text_encoder.py
+++ b/invokeai/app/invocations/flux_text_encoder.py
@@ -1,11 +1,11 @@
 from contextlib import ExitStack
-from typing import Iterator, Literal, Tuple
+from typing import Iterator, Literal, Optional, Tuple
 
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
+from invokeai.app.invocations.fields import FieldDescriptions, FluxConditioningField, Input, InputField, TensorField
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -42,6 +42,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
     prompt: str = InputField(description="Text prompt to encode.")
+    mask: Optional[TensorField] = InputField(
+        default=None, description="A mask defining the region that this conditioning prompt applies to."
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
@@ -54,7 +57,9 @@
         )
 
         conditioning_name = context.conditioning.save(conditioning_data)
-        return FluxConditioningOutput.build(conditioning_name)
+        return FluxConditioningOutput(
+            conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )
 
     def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)
diff --git a/invokeai/backend/flux/denoise.py b/invokeai/backend/flux/denoise.py
index bb0e60409a..c1cb3bbb71 100644
--- a/invokeai/backend/flux/denoise.py
+++ b/invokeai/backend/flux/denoise.py
@@ -10,6 +10,7 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import Insta
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 
 
@@ -18,14 +19,8 @@ def denoise(
     # model input
     img: torch.Tensor,
     img_ids: torch.Tensor,
-    # positive text conditioning
-    txt: torch.Tensor,
-    txt_ids: torch.Tensor,
-    vec: torch.Tensor,
-    # negative text conditioning
-    neg_txt: torch.Tensor | None,
-    neg_txt_ids: torch.Tensor | None,
-    neg_vec: torch.Tensor | None,
+    pos_text_conditioning: FluxRegionalTextConditioning,
+    neg_text_conditioning: FluxRegionalTextConditioning | None,
     # sampling parameters
     timesteps: list[float],
     step_callback: Callable[[PipelineIntermediateState], None],
@@ -55,6 +50,7 @@ def denoise(
         # Run ControlNet models.
         controlnet_residuals: list[ControlNetFluxOutput] = []
         for controlnet_extension in controlnet_extensions:
+            # FIX(ryand): Revive ControlNet functionality.
             controlnet_residuals.append(
                 controlnet_extension.run_controlnet(
                     timestep_index=step_index,
diff --git a/invokeai/backend/flux/text_conditioning.py b/invokeai/backend/flux/text_conditioning.py
new file mode 100644
index 0000000000..82d37688d1
--- /dev/null
+++ b/invokeai/backend/flux/text_conditioning.py
@@ -0,0 +1,32 @@
+from dataclasses import dataclass
+
+import torch
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
+
+
+@dataclass
+class FluxTextConditioning:
+    t5_embeddings: torch.Tensor
+    clip_embeddings: torch.Tensor
+    mask: torch.Tensor
+
+
+@dataclass
+class FluxRegionalTextConditioning:
+    # Concatenated text embeddings.
+    t5_embeddings: torch.Tensor
+    clip_embeddings: torch.Tensor
+
+    t5_txt_ids: torch.Tensor
+
+    # A binary mask indicating the regions of the image that the prompt should be applied to.
+    # Shape: (1, num_prompts, height, width)
+    # Dtype: torch.bool
+    image_masks: torch.Tensor
+
+    # List of ranges that represent the embedding ranges for each mask.
+    # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
+    # clip_embedding_ranges[i] contains the range of the clip embeddings that correspond to image_masks[i].
+    t5_embedding_ranges: list[Range]
+    clip_embedding_ranges: list[Range]
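
-- 
A quick, self-contained sketch of the regional-conditioning bookkeeping this
patch introduces. Range and TextConditioning below are hypothetical stand-ins
for the Range / FluxTextConditioning dataclasses in the diff, and the tensor
shapes are illustrative placeholders rather than real FLUX embedding sizes.
Each prompt's embeddings are concatenated along the sequence dimension, and a
Range records which slice of the concatenated sequence pairs with which mask:

    from dataclasses import dataclass

    import torch
    import torchvision


    @dataclass
    class Range:
        start: int
        end: int


    @dataclass
    class TextConditioning:
        t5_embeddings: torch.Tensor  # (1, seq_len, emb_dim)
        mask: torch.Tensor  # (1, 1, latent_h, latent_w)


    def preprocess_mask(mask: torch.Tensor, h: int, w: int) -> torch.Tensor:
        # Nearest-neighbor resize keeps the mask binary (the same idea as
        # _preprocess_regional_prompt_mask in the patch).
        tf = torchvision.transforms.Resize(
            (h, w), interpolation=torchvision.transforms.InterpolationMode.NEAREST
        )
        return tf(mask)


    def concat_regional(conds: list[TextConditioning]) -> tuple[torch.Tensor, torch.Tensor, list[Range]]:
        # Concatenate per-region embeddings along the sequence dim and record
        # which slice of the concatenated sequence belongs to which mask.
        ranges: list[Range] = []
        cur = 0
        for c in conds:
            seq_len = c.t5_embeddings.shape[1]
            ranges.append(Range(start=cur, end=cur + seq_len))
            cur += seq_len
        t5 = torch.cat([c.t5_embeddings for c in conds], dim=1)
        masks = torch.cat([c.mask for c in conds], dim=1)  # (1, num_prompts, h, w)
        return t5, masks, ranges


    # Two regional prompts: 512-token embeddings, 64x64 masks resized to a 32x32 latent grid.
    a = TextConditioning(torch.zeros(1, 512, 64), preprocess_mask(torch.ones(1, 1, 64, 64), 32, 32))
    b = TextConditioning(torch.zeros(1, 512, 64), preprocess_mask(torch.zeros(1, 1, 64, 64), 32, 32))
    t5, masks, ranges = concat_regional([a, b])
    print(t5.shape)  # torch.Size([1, 1024, 64])
    print(masks.shape)  # torch.Size([1, 2, 32, 32])
    print([(r.start, r.end) for r in ranges])  # [(0, 512), (512, 1024)]

How denoise() consumes image_masks and these ranges is outside this diff; the
sketch only covers the bookkeeping added above.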