Migrate DiffusersControlNetFlux from diffusers-style to BFL-style.

This commit is contained in:
Ryan Dick
2024-10-04 18:46:47 +00:00
parent 16cda33025
commit 1751c380db

View File

@@ -1,18 +1,20 @@
# This file was initially copied from:
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py
# TODO(ryand): Remove this file and import the model from the diffusers package instead. I have not done this yet,
# because:
# 1. The latest changes to this model in diffusers have not yet been included in a diffusers release.
# 2. We need to sort out https://github.com/invoke-ai/InvokeAI/pull/6740 before we can bump the diffusers package.
from dataclasses import dataclass
from typing import List
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.controlnet import zero_module
from diffusers.models.modeling_utils import ModelMixin
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
@@ -21,193 +23,152 @@ class DiffusersControlNetFluxOutput:
controlnet_single_block_samples: list[torch.Tensor] | None
class DiffusersControlNetFlux(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
):
# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL
# - in_channels: in_channels
# - num_layers: depth
# - num_single_layers: depth_single_blocks
# - attention_head_dim: hidden_size // num_heads
# - num_attention_heads: num_heads
# - joint_attention_dim: context_in_dim
# - pooled_projection_dim: vec_in_dim
# - guidance_embeds: guidance_embed
# - axes_dims_rope: axes_dim
class DiffusersControlNetFlux(torch.nn.Module):
def __init__(self, params: FluxParams, num_control_modes: int | None = None):
"""
Args:
params (FluxParams): The parameters for the FLUX model.
num_control_modes (int | None, optional): The number of controlnet modes. If non-None, then the model is a
'union controlnet' model and expects a mode conditioning input at runtime.
"""
super().__init__()
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
text_time_guidance_cls = (
CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
)
self.time_text_embed = text_time_guidance_cls(
embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
# The following modules mirror the base FLUX transformer model.
# -------------------------------------------------------------
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
self.transformer_blocks = nn.ModuleList(
self.double_blocks = nn.ModuleList(
[
FluxTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for i in range(num_layers)
for _ in range(params.depth)
]
)
self.single_transformer_blocks = nn.ModuleList(
self.single_blocks = nn.ModuleList(
[
FluxSingleTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
)
for i in range(num_single_layers)
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
# controlnet_blocks
# The following modules are specific to the ControlNet model.
# -----------------------------------------------------------
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_transformer_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
self.union = num_mode is not None
if self.union:
self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
self.is_union = False
if num_control_modes is not None:
self.is_union = True
self.controlnet_mode_embedder = nn.Embedding(num_control_modes, self.hidden_size)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size))
def forward(
def other_forward(
self,
hidden_states: torch.Tensor,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor = None,
conditioning_scale: float = 1.0,
encoder_hidden_states: torch.Tensor = None,
pooled_projections: torch.Tensor = None,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
guidance: torch.Tensor = None,
controlnet_mode: torch.Tensor | None,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> DiffusersControlNetFluxOutput:
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
controlnet_cond (`torch.Tensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
controlnet_mode (`torch.Tensor`):
The mode tensor of shape `(batch_size, 1)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
"""
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
hidden_states = self.x_embedder(hidden_states)
img = self.img_in(img)
# add
hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
# Add controlnet_cond embedding.
img = img + self.controlnet_x_embedder(controlnet_cond)
timestep = timestep.to(hidden_states.dtype) * 1000
if guidance is not None:
guidance = guidance.to(hidden_states.dtype) * 1000
else:
guidance = None
temb = (
self.time_text_embed(timestep, pooled_projections)
if guidance is None
else self.time_text_embed(timestep, guidance, pooled_projections)
)
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
if self.union:
# union mode
if controlnet_mode is None:
raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
# union mode emb
# If this is a union ControlNet, then concat the control mode embedding to the T5 text embedding.
if self.is_union:
assert controlnet_mode is not None
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
txt = torch.cat([controlnet_mode_emb, txt], dim=1)
txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
else:
assert controlnet_mode is None
if txt_ids.ndim == 3:
logger.warning(
"Passing `txt_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
logger.warning(
"Passing `img_ids` 3d torch.Tensor is deprecated."
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
double_block_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
double_block_samples.append(img)
block_samples = ()
for index_block, block in enumerate(self.transformer_blocks):
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
block_samples = block_samples + (hidden_states,)
img = torch.cat((txt, img), 1)
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
single_block_samples: list[torch.Tensor] = []
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
single_block_samples.append(img[:, txt.shape[1] :])
single_block_samples = ()
for index_block, block in enumerate(self.single_transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
# ControlNet Block
controlnet_double_block_samples: list[torch.Tensor] = []
for double_block_sample, controlnet_block in zip(double_block_samples, self.controlnet_blocks, strict=True):
double_block_sample = controlnet_block(double_block_sample)
controlnet_double_block_samples.append(double_block_sample)
# controlnet block
controlnet_block_samples = ()
for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks, strict=False):
block_sample = controlnet_block(block_sample)
controlnet_block_samples = controlnet_block_samples + (block_sample,)
controlnet_single_block_samples = ()
controlnet_single_block_samples: list[torch.Tensor] = []
for single_block_sample, controlnet_block in zip(
single_block_samples, self.controlnet_single_blocks, strict=False
single_block_samples, self.controlnet_single_blocks, strict=True
):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
# scaling
controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
controlnet_single_block_samples = (
None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
)
controlnet_single_block_samples.append(single_block_sample)
return DiffusersControlNetFluxOutput(
controlnet_block_samples=controlnet_block_samples,
controlnet_single_block_samples=controlnet_single_block_samples,
controlnet_block_samples=controlnet_double_block_samples or None,
controlnet_single_block_samples=controlnet_single_block_samples or None,
)