Pass RegionalPromptingExtension down to the CustomDoubleStreamBlockProcessor in FLUX.

commit fda7aaa7ca
parent 85c616fa34
Author: Ryan Dick
Date: 2024-11-20 19:48:04 +00:00

5 changed files with 159 additions and 131 deletions

View File

@@ -4,7 +4,6 @@ from typing import Callable, Iterator, Optional, Tuple
 import numpy as np
 import numpy.typing as npt
 import torch
-import torchvision
 import torchvision.transforms as tv_transforms
 from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -31,6 +30,7 @@ from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlN
 from invokeai.backend.flux.denoise import denoise
 from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.ip_adapter.xlabs_ip_adapter_flux import XlabsIpAdapterFlux
@@ -43,15 +43,14 @@ from invokeai.backend.flux.sampling_utils import (
     pack,
     unpack,
 )
-from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
+from invokeai.backend.flux.text_conditioning import FluxTextConditioning
 from invokeai.backend.lora.conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
 from invokeai.backend.lora.lora_patcher import LoRAPatcher
 from invokeai.backend.model_manager.config import ModelFormat
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo, Range
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import FLUXConditioningInfo
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.mask import to_standard_float_mask


 @invocation(
@@ -142,113 +141,6 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         name = context.tensors.save(tensor=latents)
         return LatentsOutput.build(latents_name=name, latents=latents, seed=None)

-    @staticmethod
-    def _preprocess_regional_prompt_mask(
-        mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
-    ) -> torch.Tensor:
-        """Preprocess a regional prompt mask to match the target height and width.
-
-        If mask is None, returns a mask of all ones with the target height and width.
-        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
-
-        Returns:
-            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
-        """
-        if mask is None:
-            return torch.ones((1, 1, target_height, target_width), dtype=dtype)
-
-        mask = to_standard_float_mask(mask, out_dtype=dtype)
-
-        tf = torchvision.transforms.Resize(
-            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
-        )
-
-        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
-        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
-        resized_mask = tf(mask)
-        return resized_mask
-
-    def _load_text_conditioning(
-        self,
-        context: InvocationContext,
-        cond_field: FluxConditioningField | list[FluxConditioningField],
-        latent_height: int,
-        latent_width: int,
-        dtype: torch.dtype,
-    ) -> list[FluxTextConditioning]:
-        """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
-        # Normalize to a list of FluxConditioningFields.
-        cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
-
-        text_conditionings: list[FluxTextConditioning] = []
-        for cond_field in cond_list:
-            # Load the text embeddings.
-            cond_data = context.conditioning.load(cond_field.conditioning_name)
-            assert len(cond_data.conditionings) == 1
-            flux_conditioning = cond_data.conditionings[0]
-            assert isinstance(flux_conditioning, FLUXConditioningInfo)
-            flux_conditioning = flux_conditioning.to(dtype=dtype)
-            t5_embeddings = flux_conditioning.t5_embeds
-            clip_embeddings = flux_conditioning.clip_embeds
-
-            # Load the mask, if provided.
-            mask: Optional[torch.Tensor] = None
-            if cond_field.mask is not None:
-                mask = context.tensors.load(cond_field.mask.tensor_name)
-                mask = self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype)
-
-            text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
-
-        return text_conditionings
-
-    def _concat_regional_text_conditioning(
-        self, text_conditionings: list[FluxTextConditioning]
-    ) -> FluxRegionalTextConditioning:
-        """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
-        concat_t5_embeddings: list[torch.Tensor] = []
-        concat_clip_embeddings: list[torch.Tensor] = []
-        concat_image_masks: list[torch.Tensor] = []
-        concat_t5_embedding_ranges: list[Range] = []
-        concat_clip_embedding_ranges: list[Range] = []
-
-        cur_t5_embedding_len = 0
-        cur_clip_embedding_len = 0
-        for text_conditioning in text_conditionings:
-            concat_t5_embeddings.append(text_conditioning.t5_embeddings)
-            concat_clip_embeddings.append(text_conditioning.clip_embeddings)
-            concat_t5_embedding_ranges.append(
-                Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
-            )
-            concat_clip_embedding_ranges.append(
-                Range(
-                    start=cur_clip_embedding_len,
-                    end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1],
-                )
-            )
-            concat_image_masks.append(text_conditioning.mask)
-
-            cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
-            cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1]
-
-        t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
-
-        # Initialize the txt_ids tensor.
-        pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
-        t5_txt_ids = torch.zeros(
-            pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
-        )
-
-        return FluxRegionalTextConditioning(
-            t5_embeddings=t5_embeddings,
-            clip_embeddings=torch.cat(concat_clip_embeddings, dim=1),
-            t5_txt_ids=t5_txt_ids,
-            image_masks=torch.cat(concat_image_masks, dim=1),
-            t5_embedding_ranges=concat_t5_embedding_ranges,
-            clip_embedding_ranges=concat_clip_embedding_ranges,
-        )
-
     def _run_diffusion(
         self,
         context: InvocationContext,
@@ -288,10 +180,11 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             latent_width=latent_w,
             dtype=inference_dtype,
         )
-        pos_regional_text_conditioning = self._concat_regional_text_conditioning(pos_text_conditionings)
-        neg_regional_text_conditioning = (
-            self._concat_regional_text_conditioning(neg_text_conditionings) if neg_text_conditionings else None
+        pos_regional_prompting_extension = RegionalPromptingExtension.from_text_conditioning(pos_text_conditionings)
+        neg_regional_prompting_extension = (
+            RegionalPromptingExtension.from_text_conditioning(neg_text_conditionings)
+            if neg_text_conditionings
+            else None
         )

         transformer_info = context.models.load(self.transformer.transformer)
@@ -436,8 +329,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 model=transformer,
                 img=x,
                 img_ids=img_ids,
-                pos_text_conditioning=pos_regional_text_conditioning,
-                neg_text_conditioning=neg_regional_text_conditioning,
+                pos_regional_prompting_extension=pos_regional_prompting_extension,
+                neg_regional_prompting_extension=neg_regional_prompting_extension,
                 timesteps=timesteps,
                 step_callback=self._build_step_callback(context),
                 guidance=self.guidance,
@@ -451,6 +344,39 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         x = unpack(x.float(), self.height, self.width)
         return x

+    def _load_text_conditioning(
+        self,
+        context: InvocationContext,
+        cond_field: FluxConditioningField | list[FluxConditioningField],
+        latent_height: int,
+        latent_width: int,
+        dtype: torch.dtype,
+    ) -> list[FluxTextConditioning]:
+        """Load text conditioning data from a FluxConditioningField or a list of FluxConditioningFields."""
+        # Normalize to a list of FluxConditioningFields.
+        cond_list = [cond_field] if isinstance(cond_field, FluxConditioningField) else cond_field
+
+        text_conditionings: list[FluxTextConditioning] = []
+        for cond_field in cond_list:
+            # Load the text embeddings.
+            cond_data = context.conditioning.load(cond_field.conditioning_name)
+            assert len(cond_data.conditionings) == 1
+            flux_conditioning = cond_data.conditionings[0]
+            assert isinstance(flux_conditioning, FLUXConditioningInfo)
+            flux_conditioning = flux_conditioning.to(dtype=dtype)
+            t5_embeddings = flux_conditioning.t5_embeds
+            clip_embeddings = flux_conditioning.clip_embeds
+
+            # Load the mask, if provided.
+            mask: Optional[torch.Tensor] = None
+            if cond_field.mask is not None:
+                mask = context.tensors.load(cond_field.mask.tensor_name)
+                mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype)
+
+            text_conditionings.append(FluxTextConditioning(t5_embeddings, clip_embeddings, mask))
+
+        return text_conditionings
+
     @classmethod
     def prep_cfg_scale(
         cls, cfg_scale: float | list[float], timesteps: list[float], cfg_scale_start_step: int, cfg_scale_end_step: int

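Note on the FluxDenoiseInvocation hunks above: the invocation no longer concatenates regional conditioning itself; it wraps the loaded FluxTextConditioning list in the new RegionalPromptingExtension. A minimal sketch of the resulting call pattern, with dummy zero tensors standing in for real T5/CLIP embeddings (the shapes are illustrative assumptions, not requirements):

import torch

from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.text_conditioning import FluxTextConditioning

dtype = torch.float32
latent_h, latent_w = 64, 64

# Passing mask=None yields an all-ones mask: "this prompt applies to the whole image".
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(None, latent_h, latent_w, dtype)
pos_cond = [FluxTextConditioning(torch.zeros(1, 512, 4096), torch.zeros(1, 1, 768), mask)]
neg_cond: list[FluxTextConditioning] = []  # Optional; only needed when cfg_scale != 1.0.

pos_ext = RegionalPromptingExtension.from_text_conditioning(pos_cond)
neg_ext = RegionalPromptingExtension.from_text_conditioning(neg_cond) if neg_cond else None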
View File

@@ -1,6 +1,7 @@
 import einops
 import torch

+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.math import attention
 from invokeai.backend.flux.modules.layers import DoubleStreamBlock
@@ -63,6 +64,7 @@ class CustomDoubleStreamBlockProcessor:
         vec: torch.Tensor,
         pe: torch.Tensor,
         ip_adapter_extensions: list[XLabsIPAdapterExtension],
+        regional_prompting_extension: RegionalPromptingExtension,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """A custom implementation of DoubleStreamBlock.forward() with additional features:
         - IP-Adapter support

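The hunk above only adds the extension to the CustomDoubleStreamBlockProcessor signature; nothing shown in this file consumes it yet. A toy sketch of the plumbing pattern (hypothetical names; not the real processor, which also computes attention and applies IP-Adapters):

import torch

class ToyBlockProcessor:
    @staticmethod
    def forward(
        img: torch.Tensor,
        txt: torch.Tensor,
        regional_prompting_extension: object,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # The extension is accepted but intentionally unused at this commit; a
        # follow-up can consult it here, e.g. to build a region-aware attention mask.
        return img, txt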
View File

@@ -7,10 +7,10 @@ from tqdm import tqdm
 from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
 from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
 from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.model import Flux
-from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -19,8 +19,8 @@ def denoise(
     # model input
     img: torch.Tensor,
     img_ids: torch.Tensor,
-    pos_text_conditioning: FluxRegionalTextConditioning,
-    neg_text_conditioning: FluxRegionalTextConditioning | None,
+    pos_regional_prompting_extension: RegionalPromptingExtension,
+    neg_regional_prompting_extension: RegionalPromptingExtension | None,
     # sampling parameters
     timesteps: list[float],
     step_callback: Callable[[PipelineIntermediateState], None],
@@ -50,16 +50,16 @@ def denoise(
         # Run ControlNet models.
         controlnet_residuals: list[ControlNetFluxOutput] = []
         for controlnet_extension in controlnet_extensions:
-            # FIX(ryand): Revive ControlNet functionality.
+            # TODO(ryand): Think about how to handle regional prompting with ControlNet.
             controlnet_residuals.append(
                 controlnet_extension.run_controlnet(
                     timestep_index=step_index,
                     total_num_timesteps=total_steps,
                     img=img,
                     img_ids=img_ids,
-                    txt=txt,
-                    txt_ids=txt_ids,
-                    y=vec,
+                    txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                    txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                    y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                     timesteps=t_vec,
                     guidance=guidance_vec,
                 )
@@ -74,9 +74,9 @@ def denoise(
         pred = model(
             img=img,
             img_ids=img_ids,
-            txt=txt,
-            txt_ids=txt_ids,
-            y=vec,
+            txt=pos_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+            txt_ids=pos_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+            y=pos_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
             timesteps=t_vec,
             guidance=guidance_vec,
             timestep_index=step_index,
@@ -84,6 +84,7 @@ def denoise(
             controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
             controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
             ip_adapter_extensions=pos_ip_adapter_extensions,
+            regional_prompting_extension=pos_regional_prompting_extension,
         )

         step_cfg_scale = cfg_scale[step_index]
@@ -93,15 +94,15 @@ def denoise(
             # TODO(ryand): Add option to run positive and negative predictions in a single batch for better performance
             # on systems with sufficient VRAM.
-            if neg_txt is None or neg_txt_ids is None or neg_vec is None:
+            if neg_regional_prompting_extension is None:
                 raise ValueError("Negative text conditioning is required when cfg_scale is not 1.0.")
             neg_pred = model(
                 img=img,
                 img_ids=img_ids,
-                txt=neg_txt,
-                txt_ids=neg_txt_ids,
-                y=neg_vec,
+                txt=neg_regional_prompting_extension.regional_text_conditioning.t5_embeddings,
+                txt_ids=neg_regional_prompting_extension.regional_text_conditioning.t5_txt_ids,
+                y=neg_regional_prompting_extension.regional_text_conditioning.clip_embeddings,
                 timesteps=t_vec,
                 guidance=guidance_vec,
                 timestep_index=step_index,

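For context: the pred/neg_pred pair computed above feeds a classifier-free guidance combine step that falls outside the hunks shown. A hedged sketch of the usual formula (assumed, not taken from this diff):

import torch

def apply_cfg(pos_pred: torch.Tensor, neg_pred: torch.Tensor, step_cfg_scale: float) -> torch.Tensor:
    # Standard CFG: start at the negative prediction and move toward the
    # positive prediction, scaled by the per-step guidance weight.
    return neg_pred + step_cfg_scale * (pos_pred - neg_pred)

At step_cfg_scale == 1.0 this reduces to pos_pred, which is why the negative pass (and hence neg_regional_prompting_extension) is only required for other scales, matching the ValueError above.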
View File

@@ -0,0 +1,96 @@
+from typing import Optional
+
+import torch
+import torchvision
+
+from invokeai.backend.flux.text_conditioning import FluxRegionalTextConditioning, FluxTextConditioning
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Range
+from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.mask import to_standard_float_mask
+
+
+class RegionalPromptingExtension:
+    """A class for managing regional prompting with FLUX."""
+
+    def __init__(self, regional_text_conditioning: FluxRegionalTextConditioning):
+        self.regional_text_conditioning = regional_text_conditioning
+
+    @classmethod
+    def from_text_conditioning(cls, text_conditioning: list[FluxTextConditioning]):
+        return cls(regional_text_conditioning=cls._concat_regional_text_conditioning(text_conditioning))
+
+    @classmethod
+    def _concat_regional_text_conditioning(
+        cls,
+        text_conditionings: list[FluxTextConditioning],
+    ) -> FluxRegionalTextConditioning:
+        """Concatenate regional text conditioning data into a single conditioning tensor (with associated masks)."""
+        concat_t5_embeddings: list[torch.Tensor] = []
+        concat_clip_embeddings: list[torch.Tensor] = []
+        concat_image_masks: list[torch.Tensor] = []
+        concat_t5_embedding_ranges: list[Range] = []
+        concat_clip_embedding_ranges: list[Range] = []
+
+        cur_t5_embedding_len = 0
+        cur_clip_embedding_len = 0
+        for text_conditioning in text_conditionings:
+            concat_t5_embeddings.append(text_conditioning.t5_embeddings)
+            concat_clip_embeddings.append(text_conditioning.clip_embeddings)
+            concat_t5_embedding_ranges.append(
+                Range(start=cur_t5_embedding_len, end=cur_t5_embedding_len + text_conditioning.t5_embeddings.shape[1])
+            )
+            concat_clip_embedding_ranges.append(
+                Range(
+                    start=cur_clip_embedding_len,
+                    end=cur_clip_embedding_len + text_conditioning.clip_embeddings.shape[1],
+                )
+            )
+            concat_image_masks.append(text_conditioning.mask)
+
+            cur_t5_embedding_len += text_conditioning.t5_embeddings.shape[1]
+            cur_clip_embedding_len += text_conditioning.clip_embeddings.shape[1]
+
+        t5_embeddings = torch.cat(concat_t5_embeddings, dim=1)
+
+        # Initialize the txt_ids tensor.
+        pos_bs, pos_t5_seq_len, _ = t5_embeddings.shape
+        t5_txt_ids = torch.zeros(
+            pos_bs, pos_t5_seq_len, 3, dtype=t5_embeddings.dtype, device=TorchDevice.choose_torch_device()
+        )
+
+        return FluxRegionalTextConditioning(
+            t5_embeddings=t5_embeddings,
+            clip_embeddings=torch.cat(concat_clip_embeddings, dim=1),
+            t5_txt_ids=t5_txt_ids,
+            image_masks=torch.cat(concat_image_masks, dim=1),
+            t5_embedding_ranges=concat_t5_embedding_ranges,
+            clip_embedding_ranges=concat_clip_embedding_ranges,
+        )
+
+    @staticmethod
+    def preprocess_regional_prompt_mask(
+        mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
+    ) -> torch.Tensor:
+        """Preprocess a regional prompt mask to match the target height and width.
+
+        If mask is None, returns a mask of all ones with the target height and width.
+        If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
+
+        Returns:
+            torch.Tensor: The processed mask. shape: (1, 1, target_height, target_width).
+        """
+        if mask is None:
+            return torch.ones((1, 1, target_height, target_width), dtype=dtype)
+
+        mask = to_standard_float_mask(mask, out_dtype=dtype)
+
+        tf = torchvision.transforms.Resize(
+            (target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+        )
+
+        # Add a batch dimension to the mask, because torchvision expects shape (batch, channels, h, w).
+        mask = mask.unsqueeze(0)  # Shape: (1, h, w) -> (1, 1, h, w)
+        resized_mask = tf(mask)
+        return resized_mask

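A short walk-through of the concatenation bookkeeping in the new file, assuming two prompts with T5 sequence lengths 512 and 77 (dummy zero tensors; the embedding dims are illustrative assumptions):

import torch

from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
from invokeai.backend.flux.text_conditioning import FluxTextConditioning

dtype = torch.float32
h = w = 64
mask = RegionalPromptingExtension.preprocess_regional_prompt_mask(None, h, w, dtype)  # (1, 1, 64, 64), all ones

cond_a = FluxTextConditioning(torch.zeros(1, 512, 4096), torch.zeros(1, 1, 768), mask)
cond_b = FluxTextConditioning(torch.zeros(1, 77, 4096), torch.zeros(1, 1, 768), mask)

ext = RegionalPromptingExtension.from_text_conditioning([cond_a, cond_b])
rc = ext.regional_text_conditioning

print(rc.t5_embeddings.shape)  # torch.Size([1, 589, 4096]): prompts concatenated along the sequence dim
print(rc.t5_embedding_ranges)  # ranges covering [0, 512) and [512, 589), one per prompt
print(rc.image_masks.shape)    # torch.Size([1, 2, 64, 64]): one mask channel per prompt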
View File

@@ -6,6 +6,7 @@ import torch
 from torch import Tensor, nn

 from invokeai.backend.flux.custom_block_processor import CustomDoubleStreamBlockProcessor
+from invokeai.backend.flux.extensions.regional_prompting_extension import RegionalPromptingExtension
 from invokeai.backend.flux.extensions.xlabs_ip_adapter_extension import XLabsIPAdapterExtension
 from invokeai.backend.flux.modules.layers import (
     DoubleStreamBlock,
@@ -95,6 +96,7 @@ class Flux(nn.Module):
         controlnet_double_block_residuals: list[Tensor] | None,
         controlnet_single_block_residuals: list[Tensor] | None,
         ip_adapter_extensions: list[XLabsIPAdapterExtension],
+        regional_prompting_extension: RegionalPromptingExtension,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -128,6 +130,7 @@ class Flux(nn.Module):
                 vec=vec,
                 pe=pe,
                 ip_adapter_extensions=ip_adapter_extensions,
+                regional_prompting_extension=regional_prompting_extension,
             )

         if controlnet_double_block_residuals is not None:
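
Taken together, the five files thread a single object through the whole call stack: the invocation builds the RegionalPromptingExtension, denoise() receives it, Flux.forward() accepts it, and the double-stream block processor finally sees it. A self-contained toy of that pass-through pattern (all names hypothetical):

from dataclasses import dataclass

@dataclass
class ToyExtension:
    payload: str

def toy_block_forward(x: float, extension: ToyExtension) -> float:
    # Innermost consumer: a later change can act on extension.payload here.
    return x

def toy_model_forward(x: float, extension: ToyExtension) -> float:
    # Mirrors Flux.forward(): accept the extension and pass it down unchanged.
    return toy_block_forward(x, extension)

def toy_denoise(x: float, extension: ToyExtension) -> float:
    # Mirrors denoise(): the top of the chain.
    return toy_model_forward(x, extension)

toy_denoise(0.0, ToyExtension(payload="regional masks"))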