# This file was initially copied from:
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py
# TODO(ryand): Remove this file and import the model from the diffusers package instead. I have not done this yet,
# because:
# 1. The latest changes to this model in diffusers have not yet been included in a diffusers release.
# 2. We need to sort out https://github.com/invoke-ai/InvokeAI/pull/6740 before we can bump the diffusers package.

from dataclasses import dataclass

import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.controlnet import zero_module

# NOTE: The imports below were missing from this copy; these paths match recent diffusers
# releases, but older/newer versions may expose the same symbols from different modules.
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
    FluxPosEmbed,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
from diffusers.utils import logging

logger = logging.get_logger(__name__)


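# Output of DiffusersControlNetFlux.forward(): per-block ControlNet residuals for the
# double-stream (`controlnet_block_samples`) and single-stream (`controlnet_single_block_samples`)
# transformer blocks, or None when the corresponding list would be empty.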
@dataclass
class DiffusersControlNetFluxOutput:
    controlnet_block_samples: list[torch.Tensor] | None
    controlnet_single_block_samples: list[torch.Tensor] | None


class DiffusersControlNetFlux(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: list[int] = [16, 56, 56],
        num_mode: int | None = None,
    ):
        super().__init__()
        self.out_channels = in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

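        # Rotary position embedding over the packed (text, image) token ids; the timestep/text
        # (and, for guidance-distilled models, guidance) projections produce the modulation
        # signal `temb` consumed by every transformer block.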
        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
        )

        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)

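        # Mirror of the Flux transformer trunk: double-stream blocks keep separate text/image
        # streams, then single-stream blocks run over the concatenated sequence.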
        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_single_layers)
            ]
        )

        # controlnet_blocks: zero-initialized projections that turn each transformer block's
        # hidden states into ControlNet residuals, so the ControlNet starts as a no-op.
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_transformer_blocks)):
            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))

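        # ControlNet-Union support: when `num_mode` is set, a learned embedding token selects
        # the conditioning mode and is prepended to the text stream in forward().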
        self.union = num_mode is not None
        if self.union:
            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)

        self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        controlnet_mode: torch.Tensor | None = None,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor | None = None,
        pooled_projections: torch.Tensor | None = None,
        timestep: torch.LongTensor | None = None,
        img_ids: torch.Tensor | None = None,
        txt_ids: torch.Tensor | None = None,
        guidance: torch.Tensor | None = None,
    ) -> DiffusersControlNetFluxOutput:
"""
|
|
Args:
|
|
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
|
Input `hidden_states`.
|
|
controlnet_cond (`torch.Tensor`):
|
|
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
|
controlnet_mode (`torch.Tensor`):
|
|
The mode tensor of shape `(batch_size, 1)`.
|
|
conditioning_scale (`float`, defaults to `1.0`):
|
|
The scale factor for ControlNet outputs.
|
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
|
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
|
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
|
from the embeddings of input conditions.
|
|
timestep ( `torch.LongTensor`):
|
|
Used to indicate denoising step.
|
|
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
|
A list of tensors that if specified are added to the residuals of transformer blocks.
|
|
"""
|
|
|
|
        hidden_states = self.x_embedder(hidden_states)

        # Add the (zero-initialized) projection of the control condition to the image tokens.
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

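        # Timestep (and optional guidance) values arrive in [0, 1] and are rescaled to the
        # [0, 1000] range that the sinusoidal projection layers expect.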
        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if self.union:
            # ControlNet-Union: embed the conditioning mode and prepend it to the text stream.
            if controlnet_mode is None:
                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
            # Extend txt_ids to account for the prepended mode token.
            txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated. "
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

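        # Build one rotary embedding over the concatenated (text, image) position ids so that
        # text and image tokens attend in a shared coordinate space.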
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

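        # Double-stream blocks process text and image tokens in separate streams; the
        # image-stream hidden states after each block are captured as ControlNet taps.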
        block_samples = ()
        for block in self.transformer_blocks:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )
            block_samples = block_samples + (hidden_states,)

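        # Single-stream blocks operate on the concatenated (text, image) sequence; only the
        # image-token slice of each block's output is kept as a ControlNet tap.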
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        single_block_samples = ()
        for block in self.single_transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )
            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)

        # Project each tapped hidden state through its zero-initialized ControlNet block.
        controlnet_block_samples = ()
        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks, strict=False):
            block_sample = controlnet_block(block_sample)
            controlnet_block_samples = controlnet_block_samples + (block_sample,)

        controlnet_single_block_samples = ()
        for single_block_sample, controlnet_block in zip(
            single_block_samples, self.controlnet_single_blocks, strict=False
        ):
            single_block_sample = controlnet_block(single_block_sample)
            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)

        # Apply the user-facing conditioning scale to every residual.
        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]

        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
        controlnet_single_block_samples = (
            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
        )

        return DiffusersControlNetFluxOutput(
            controlnet_block_samples=controlnet_block_samples,
            controlnet_single_block_samples=controlnet_single_block_samples,
        )
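

# Minimal smoke-test sketch (not part of the original file). It assumes the Flux block APIs
# from diffusers releases close to the pinned commit, and uses a deliberately tiny, non-default
# config so it runs quickly on CPU; real checkpoints use the defaults registered in __init__.
if __name__ == "__main__":
    model = DiffusersControlNetFlux(
        num_layers=1,
        num_single_layers=1,
        attention_head_dim=8,
        num_attention_heads=2,
        joint_attention_dim=32,
        pooled_projection_dim=16,
        axes_dims_rope=[2, 2, 4],  # must sum to attention_head_dim; each entry must be even
    )
    batch, img_seq, txt_seq = 1, 16, 4
    out = model(
        hidden_states=torch.randn(batch, img_seq, 64),  # packed latents; last dim == in_channels
        controlnet_cond=torch.randn(batch, img_seq, 64),
        conditioning_scale=0.5,
        encoder_hidden_states=torch.randn(batch, txt_seq, 32),
        pooled_projections=torch.randn(batch, 16),
        timestep=torch.tensor([1.0]),
        img_ids=torch.zeros(img_seq, 3),
        txt_ids=torch.zeros(txt_seq, 3),
    )
    # One residual per double-stream and per single-stream block.
    print(len(out.controlnet_block_samples), len(out.controlnet_single_block_samples))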