WIP - Pass prompt masks to FLUX model during denoising.

Ryan Dick
2024-11-20 18:51:43 +00:00
parent 1948ffe106
commit 85c616fa34
5 changed files with 186 additions and 54 deletions

invokeai/app/invocations/fields.py

@@ -250,6 +250,11 @@ class FluxConditioningField(BaseModel):
     """A conditioning tensor primitive value"""
 
     conditioning_name: str = Field(description="The name of conditioning tensor")
+    mask: Optional[TensorField] = Field(
+        default=None,
+        description="The mask associated with this conditioning tensor. Excluded regions should be set to False, "
+        "included regions should be set to True.",
+    )
 
 
 class SD3ConditioningField(BaseModel):
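A conditioning field now carries an optional reference to a region mask alongside the conditioning tensor name. As a quick standalone sketch of the intended convention (True = included, False = excluded; the `ctx` InvocationContext and the saved conditioning name are assumed purely for illustration):

    import torch

    # A boolean mask selecting the left half of a 1024x1024 image.
    mask = torch.zeros(1, 1024, 1024, dtype=torch.bool)
    mask[:, :, :512] = True  # True = region the prompt applies to.

    tensor_name = ctx.tensors.save(tensor=mask)  # ctx: assumed InvocationContext
    cond_field = FluxConditioningField(
        conditioning_name="my_saved_conditioning",  # hypothetical name
        mask=TensorField(tensor_name=tensor_name),
    )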

invokeai/app/invocations/flux_denoise.py

@@ -4,6 +4,7 @@ from typing import Callable, Iterator, Optional, Tuple
 import numpy as np
 import numpy.typing as npt
 import torch
+import torchvision
 import torchvision.transforms as tv_transforms
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -42,13 +43,15 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
+from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
 from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo, Range
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
 
 
 @invocation(
@@ -87,10 +90,10 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         input=Input.Connection,
         title="Transformer",
     )
-    positive_text_conditioning: FluxConditioningField = InputField(
+    positive_text_conditioning: FluxConditioningField | list[FluxConditioningField] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
     )
-    negative_text_conditioning: FluxConditioningField | None = InputField(
+    negative_text_conditioning: FluxConditioningField | list[FluxConditioningField] | None = InputField(
         default=None,
         description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
         input=Input.Connection,
@@ -139,18 +142,112 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         name = context.tensors.save(tensor=latents)
         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
 
+    @staticmethod
+    def _preprocess_regional_prompt_mask(
+        mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
+        """
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=dtype)
+
+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+
+        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
+        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
+        resized_mask = tf(mask)
+        return resized_mask
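A standalone illustration (not part of the commit) of why 'nearest' interpolation is used here: unlike bilinear resizing, it keeps a 0/1 mask strictly binary when scaling down to the latent grid.

    import torch
    import torchvision.transforms as T

    mask = torch.zeros(1, 1, 64, 64)
    mask[..., :32] = 1.0  # left half included

    resize = T.Resize((8, 8), interpolation=T.InterpolationMode.NEAREST)
    small = resize(mask)

    print(small.shape)     # torch.Size([1, 1, 8, 8])
    print(small.unique())  # tensor([0., 1.]) -- no fractional values introduced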
     def _load_text_conditioning(
-        self, context: InvocationContext, conditioning_name: str, dtype: torch.dtype
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # Load the conditioning data.
-        cond_data = context.conditioning.load(conditioning_name)
-        assert len(cond_data.conditionings) == 1
-        flux_conditioning = cond_data.conditionings[0]
-        assert isinstance(flux_conditioning, FLUXConditioningInfo)
-        flux_conditioning = flux_conditioning.to(dtype=dtype)
-        t5_embeddings = flux_conditioning.t5_embeds
-        clip_embeddings = flux_conditioning.clip_embeds
-        return t5_embeddings, clip_embeddings
+        self,
+        context: InvocationContext,
+        cond_field: FluxConditioningField | list[FluxConditioningField],
+        latent_height: int,
+        latent_width: int,
+        dtype: torch.dtype,
+    ) -> list[FluxTextConditioning]:
+        """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
+        # Normalize to a list of FluxConditioningFields.
+        cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
+
+        text_conditionings: list[FluxTextConditioning] = []
+        for cond_field in cond_list:
+            # Load the text embeddings.
+            cond_data = context.conditioning.load(cond_field.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            flux_conditioning = cond_data.conditionings[0]
+            assert isinstance(flux_conditioning, FLUXConditioningInfo)
+            flux_conditioning = flux_conditioning.to(dtype=dtype)
+            t5_embeddings = flux_conditioning.t5_embeds
+            clip_embeddings = flux_conditioning.clip_embeds
+
+            # Load the mask, if provided.
+            mask: Optional[torch.Tensor] = None
+            if cond_field.mask is not None:
+                mask = context.tensors.load(cond_field.mask.tensor_name)
+            mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype)
+
+            text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
+
+        return text_conditionings
+    def _concat_regional_text_conditioning(
+        self, text_conditionings: list[FluxTextConditioning]
+    ) -> FluxRegionalTextConditioning:
+        """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
+        concat_t5_embeddings: list[torch.Tensor] = []
+        concat_clip_embeddings: list[torch.Tensor] = []
+        concat_image_masks: list[torch.Tensor] = []
+        concat_t5_embedding_ranges: list[Range] = []
+        concat_clip_embedding_ranges: list[Range] = []
+
+        cur_t5_embedding_len = 0
+        cur_clip_embedding_len = 0
+        for text_conditioning in text_conditionings:
+            concat_t5_embeddings.append(text_conditioning.t5_embeddings)
+            concat_clip_embeddings.append(text_conditioning.clip_embeddings)
+            concat_t5_embedding_ranges.append(
+                Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
+            )
+            concat_clip_embedding_ranges.append(
+                Range(
+                    start=cur_clip_embedding_len,
+                    end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1],
+                )
+            )
+            concat_image_masks.append(text_conditioning.mask)
+
+            cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
+            cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1]
+
+        t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
+
+        # Initialize the txt_ids tensor.
+        pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
+        t5_txt_ids = torch.zeros(
+            pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
+        )
+
+        return FluxRegionalTextConditioning(
+            t5_embeddings=t5_embeddings,
+            clip_embeddings=torch.cat(concat_clip_embeddings, dim=1),
+            t5_txt_ids=t5_txt_ids,
+            image_masks=torch.cat(concat_image_masks, dim=1),
+            t5_embedding_ranges=concat_t5_embedding_ranges,
+            clip_embedding_ranges=concat_clip_embedding_ranges,
+        )
 
     def _run_diffusion(
         self,
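The range bookkeeping above is easiest to see with concrete numbers. A minimal self-contained sketch (shapes invented for illustration): two prompts whose T5 embeddings have sequence lengths 512 and 256 are concatenated along dim=1, yielding ranges [0, 512) and [512, 768) that later let the model map each slice of the combined sequence back to its prompt's mask.

    import torch

    t5_a = torch.zeros(1, 512, 4096)  # prompt 0 embeddings: (batch, seq, features)
    t5_b = torch.zeros(1, 256, 4096)  # prompt 1 embeddings

    ranges: list[tuple[int, int]] = []
    cur = 0
    for emb in (t5_a, t5_b):
        ranges.append((cur, cur + emb.shape[1]))
        cur += emb.shape[1]

    combined = torch.cat([t5_a, t5_b], dim=1)
    print(ranges)          # [(0, 512), (512, 768)]
    print(combined.shape)  # torch.Size([1, 768, 4096])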
@@ -158,17 +255,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
     ):
         inference_dtype = torch.bfloat16
 
-        # Load the conditioning data.
-        pos_t5_embeddings, pos_clip_embeddings = self._load_text_conditioning(
-            context, self.positive_text_conditioning.conditioning_name, inference_dtype
-        )
-        neg_t5_embeddings: torch.Tensor | None = None
-        neg_clip_embeddings: torch.Tensor | None = None
-        if self.negative_text_conditioning is not None:
-            neg_t5_embeddings, neg_clip_embeddings = self._load_text_conditioning(
-                context, self.negative_text_conditioning.conditioning_name, inference_dtype
-            )
-
         # Load the input latents, if provided.
         init_latents = context.tensors.load(self.latents.latents_name) if self.latents else None
         if init_latents is not None:
@@ -183,6 +269,30 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             dtype=inference_dtype,
             seed=self.seed,
         )
+        b, _c, latent_h, latent_w = noise.shape
+
+        # Load the conditioning data.
+        pos_text_conditionings = self._load_text_conditioning(
+            context=context,
+            cond_field=self.positive_text_conditioning,
+            latent_height=latent_h,
+            latent_width=latent_w,
+            dtype=inference_dtype,
+        )
+        neg_text_conditionings: list[FluxTextConditioning] | None = None
+        if self.negative_text_conditioning is not None:
+            neg_text_conditionings = self._load_text_conditioning(
+                context=context,
+                cond_field=self.negative_text_conditioning,
+                latent_height=latent_h,
+                latent_width=latent_w,
+                dtype=inference_dtype,
+            )
+        pos_regional_text_conditioning = self._concat_regional_text_conditioning(pos_text_conditionings)
+        neg_regional_text_conditioning = (
+            self._concat_regional_text_conditioning(neg_text_conditionings) if neg_text_conditionings else None
+        )
 
         transformer_info = context.models.load(self.transformer.transformer)
         is_schnell = "schnell" in transformer_info.config.config_path
@@ -228,20 +338,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         inpaint_mask = self._prep_inpaint_mask(context, x)
 
         b, _c, latent_h, latent_w = x.shape
         img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
 
-        pos_bs, pos_t5_seq_len, _ = pos_t5_embeddings.shape
-        pos_txt_ids = torch.zeros(
-            pos_bs, pos_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
-        )
-        neg_txt_ids: torch.Tensor | None = None
-        if neg_t5_embeddings is not None:
-            neg_bs, neg_t5_seq_len, _ = neg_t5_embeddings.shape
-            neg_txt_ids = torch.zeros(
-                neg_bs, neg_t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device()
-            )
-
         # Pack all latent tensors.
         init_latents = pack(init_latents) if init_latents is not None else None
         inpaint_mask = pack(inpaint_mask) if inpaint_mask is not None else None
@@ -338,12 +436,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 model=transformer,
                 img=x,
                 img_ids=img_ids,
-                txt=pos_t5_embeddings,
-                txt_ids=pos_txt_ids,
-                vec=pos_clip_embeddings,
-                neg_txt=neg_t5_embeddings,
-                neg_txt_ids=neg_txt_ids,
-                neg_vec=neg_clip_embeddings,
+                pos_text_conditioning=pos_regional_text_conditioning,
+                neg_text_conditioning=neg_regional_text_conditioning,
                 timesteps=timesteps,
                 step_callback=self._build_step_callback(context),
                 guidance=self.guidance,

invokeai/app/invocations/flux_text_encoder.py

@@ -1,11 +1,11 @@
 from contextlib import ExitStack
-from typing import Iterator, Literal, Tuple
+from typing import Iterator, Literal, Optional, Tuple
 
 import torch
 from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
-from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField
+from invokeai.app.invocations.fields import FieldDescriptions, FluxConditioningField, Input, InputField, TensorField
 from invokeai.app.invocations.model import CLIPField, T5EncoderField
 from invokeai.app.invocations.primitives import FluxConditioningOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
@@ -42,6 +42,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
         description="Max sequence length for the T5 encoder. Expected to be 256 for FLUX schnell models and 512 for FLUX dev models."
     )
     prompt: str = InputField(description="Text prompt to encode.")
+    mask: Optional[TensorField] = InputField(
+        default=None, description="A mask defining the region that this conditioning prompt applies to."
+    )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> FluxConditioningOutput:
@@ -54,7 +57,9 @@ class FluxTextEncoderInvocation(BaseInvocation):
         )
 
         conditioning_name = context.conditioning.save(conditioning_data)
-        return FluxConditioningOutput.build(conditioning_name)
+        return FluxConditioningOutput(
+            conditioning=FluxConditioningField(conditioning_name=conditioning_name, mask=self.mask)
+        )
 
     def _t5_encode(self, context: InvocationContext) -> torch.Tensor:
         t5_tokenizer_info = context.models.load(self.t5_encoder.tokenizer)

invokeai/backend/flux/denoise.py

@@ -10,6 +10,7 @@ from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.model import Flux
+from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -18,14 +19,8 @@ def denoise(
     # model input
     img: torch.Tensor,
     img_ids: torch.Tensor,
-    # positive text conditioning
-    txt: torch.Tensor,
-    txt_ids: torch.Tensor,
-    vec: torch.Tensor,
-    # negative text conditioning
-    neg_txt: torch.Tensor | None,
-    neg_txt_ids: torch.Tensor | None,
-    neg_vec: torch.Tensor | None,
+    pos_text_conditioning: FluxRegionalTextConditioning,
+    neg_text_conditioning: FluxRegionalTextConditioning | None,
     # sampling parameters
     timesteps: list[float],
     step_callback: Callable[[PipelineIntermediateState], None],
@@ -55,6 +50,7 @@ def denoise(
         # Run ControlNet models.
         controlnet_residuals: list[ControlNetFluxOutput] = []
         for controlnet_extension in controlnet_extensions:
+            # FIX(ryand): Revive ControlNet functionality.
             controlnet_residuals.append(
                 controlnet_extension.run_controlnet(
                     timestep_index=step_index,
invokeai/backend/flux/text_conditioning.py

@@ -0,0 +1,32 @@
+from dataclasses import dataclass
+
+import torch
+
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
+
+
+@dataclass
+class FluxTextConditioning:
+    t5_embeddings: torch.Tensor
+    clip_embeddings: torch.Tensor
+    mask: torch.Tensor
+
+
+@dataclass
+class FluxRegionalTextConditioning:
+    # Concatenated text embeddings.
+    t5_embeddings: torch.Tensor
+    clip_embeddings: torch.Tensor
+    t5_txt_ids: torch.Tensor
+
+    # A binary mask indicating the regions of the image that the prompt should be applied to.
+    # Shape: (1, num_prompts, height, width)
+    # Dtype: torch.bool
+    image_masks: torch.Tensor
+
+    # List of ranges that represent the embedding ranges for each mask.
+    # t5_embedding_ranges[i] contains the range of the t5 embeddings that correspond to image_masks[i].
+    # clip_embedding_ranges[i] contains the range of the clip embeddings that correspond to image_masks[i].
+    t5_embedding_ranges: list[Range]
+    clip_embedding_ranges: list[Range]
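Putting the pieces together, here is a self-contained sketch of what a populated FluxRegionalTextConditioning might look like for two prompts at a 16x16 latent grid. Range is stood in by a local dataclass, and the embedding dimensions (4096 for T5, 768 for pooled CLIP) are assumptions for illustration. Note that pooled CLIP embeddings of shape (1, 768) are concatenated along dim=1 just like the T5 sequences, so their ranges step in units of 768.

    from dataclasses import dataclass

    import torch

    @dataclass
    class Range:  # local stand-in for invokeai's Range
        start: int
        end: int

    # Two prompts: T5 sequence lengths 512 and 256; pooled CLIP embeds are (1, 768) each.
    t5_embeddings = torch.zeros(1, 512 + 256, 4096)
    clip_embeddings = torch.zeros(1, 2 * 768)
    t5_txt_ids = torch.zeros(1, 512 + 256, 3)

    image_masks = torch.zeros(1, 2, 16, 16, dtype=torch.bool)
    image_masks[:, 0, :, :8] = True  # prompt 0 -> left half
    image_masks[:, 1, :, 8:] = True  # prompt 1 -> right half

    t5_embedding_ranges = [Range(0, 512), Range(512, 768)]
    clip_embedding_ranges = [Range(0, 768), Range(768, 1536)]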