Add support for FLUX ControlNet models (XLabs and InstantX) (#7070)

## Summary

Add support for FLUX ControlNet models (XLabs and InstantX).

## QA Instructions

- [x] SD1 and SDXL ControlNets still work, since the `ModelLoaderRegistry`
registration calls were changed.
- [x] Single XLabs ControlNet
- [x] Single InstantX Union ControlNet
- [x] Single InstantX ControlNet
- [x] Single Shakker Labs Union ControlNet
- [x] Multiple ControlNets
- [x] Weight, start, end params all work as expected
- [x] Can be used with image-to-image and inpainting.
- [x] Clear error message if no VAE is passed when using an InstantX
ControlNet.
- [x] Install InstantX ControlNet in diffusers format from HF repo
(`InstantX/FLUX.1-dev-Controlnet-Union`)
- [x] Test all FLUX ControlNet starter models

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
Authored by Ryan Dick on 2024-10-10 12:37:09 -04:00, committed by GitHub.
30 changed files with 2245 additions and 47 deletions

View File

@@ -192,6 +192,7 @@ class FieldDescriptions:
freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."
class ImageField(BaseModel):
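For reference, a minimal sketch of the control-mode mapping described in the new `instantx_control_mode` field description. The dictionary and helper below are hypothetical and not part of this PR; a negative value means "no mode" (None):

```python
# Hypothetical helper illustrating the standard InstantX union mode mapping.
INSTANTX_CONTROL_MODES = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "low_quality": 6,
}

def resolve_instantx_control_mode(name: str | None) -> int:
    """Return the integer mode for a named control type, or -1 for None."""
    if name is None:
        return -1
    return INSTANTX_CONTROL_MODES[name]
```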

View File

@@ -0,0 +1,99 @@
from pydantic import BaseModel, Field, field_validator, model_validator
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
Classification,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES
class FluxControlNetField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
instantx_control_mode: int | None = Field(default=-1, description=FieldDescriptions.instantx_control_mode)
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
@invocation_output("flux_controlnet_output")
class FluxControlNetOutput(BaseInvocationOutput):
"""FLUX ControlNet info"""
control: FluxControlNetField = OutputField(description=FieldDescriptions.control)
@invocation(
"flux_controlnet",
title="FLUX ControlNet",
tags=["controlnet", "flux"],
category="controlnet",
version="1.0.0",
classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
"""Collect FLUX ControlNet info to pass to other nodes."""
image: ImageField = InputField(description="The control image")
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float | list[float] = InputField(
default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
)
begin_step_percent: float = InputField(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
# Note: We default to -1 instead of None, because in the workflow editor UI None is not currently supported.
instantx_control_mode: int | None = InputField(default=-1, description=FieldDescriptions.instantx_control_mode)
@field_validator("control_weight")
@classmethod
def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
validate_weights(v)
return v
@model_validator(mode="after")
def validate_begin_end_step_percent(self):
validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
return self
def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
return FluxControlNetOutput(
control=FluxControlNetField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
instantx_control_mode=self.instantx_control_mode,
),
)
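For illustration, a hedged sketch of the field this invocation produces. The image name and `controlnet_model_id` placeholder are assumptions; in practice these fields are wired up in the workflow editor rather than constructed by hand:

```python
# A minimal sketch (assumed values) mirroring FluxControlNetInvocation.invoke().
control = FluxControlNetField(
    image=ImageField(image_name="canny_edges.png"),  # assumed image name
    control_model=controlnet_model_id,  # assumed ModelIdentifierField for an installed FLUX ControlNet
    control_weight=0.75,
    begin_step_percent=0.0,
    end_step_percent=0.8,
    resize_mode="just_resize",
    instantx_control_mode=0,  # canny; ignored by XLabs ControlNet models
)
```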

View File

@@ -16,11 +16,16 @@ from invokeai.app.invocations.fields import (
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import TransformerField
from invokeai.app.invocations.flux_controlnet import FluxControlNetField
from invokeai.app.invocations.model import TransformerField, VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.denoise import denoise
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.sampling_utils import (
clip_timestep_schedule_fractional,
@@ -44,7 +49,7 @@ from invokeai.backend.util.devices import TorchDevice
title="FLUX Denoise",
tags=["image", "flux"],
category="image",
version="3.0.0",
version="3.1.0",
classification=Classification.Prototype,
)
class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,6 +92,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
)
seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
default=None, input=Input.Connection, description="ControlNet models."
)
controlnet_vae: VAEField | None = InputField(
description=FieldDescriptions.vae,
input=Input.Connection,
)
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -167,8 +179,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
inpaint_mask = self._prep_inpaint_mask(context, x)
b, _c, h, w = x.shape
img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
b, _c, latent_h, latent_w = x.shape
img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)
bs, t5_seq_len, _ = t5_embeddings.shape
txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
@@ -192,12 +204,21 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
noise=noise,
)
with (
transformer_info.model_on_device() as (cached_weights, transformer),
ExitStack() as exit_stack,
):
assert isinstance(transformer, Flux)
with ExitStack() as exit_stack:
# Prepare ControlNet extensions.
# Note: We do this before loading the transformer model to minimize peak memory (see implementation).
controlnet_extensions = self._prep_controlnet_extensions(
context=context,
exit_stack=exit_stack,
latent_height=latent_h,
latent_width=latent_w,
dtype=inference_dtype,
device=x.device,
)
# Load the transformer model.
(cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
assert isinstance(transformer, Flux)
config = transformer_info.config
assert config is not None
@@ -242,6 +263,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
step_callback=self._build_step_callback(context),
guidance=self.guidance,
inpaint_extension=inpaint_extension,
controlnet_extensions=controlnet_extensions,
)
x = unpack(x.float(), self.height, self.width)
@@ -288,6 +310,104 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
# `latents`.
return mask.expand_as(latents)
def _prep_controlnet_extensions(
self,
context: InvocationContext,
exit_stack: ExitStack,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
# Normalize the controlnet input to list[FluxControlNetField].
controlnets: list[FluxControlNetField]
if self.control is None:
controlnets = []
elif isinstance(self.control, FluxControlNetField):
controlnets = [self.control]
elif isinstance(self.control, list):
controlnets = self.control
else:
raise ValueError(f"Unsupported controlnet type: {type(self.control)}")
# TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
# before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
# minimize peak memory.
# First, load the ControlNet models so that we can determine the ControlNet types.
controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
# Calculate the controlnet conditioning tensors.
# We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
# keep peak memory down.
controlnet_conds: list[torch.Tensor] = []
for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
image = context.images.get_pil(controlnet.image.image_name)
if isinstance(controlnet_model.model, InstantXControlNetFlux):
if self.controlnet_vae is None:
raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
vae_info = context.models.load(self.controlnet_vae.vae)
controlnet_conds.append(
InstantXControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
vae_info=vae_info,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
resize_mode=controlnet.resize_mode,
)
)
elif isinstance(controlnet_model.model, XLabsControlNetFlux):
controlnet_conds.append(
XLabsControlNetExtension.prepare_controlnet_cond(
controlnet_image=image,
latent_height=latent_height,
latent_width=latent_width,
dtype=dtype,
device=device,
resize_mode=controlnet.resize_mode,
)
)
# Finally, load the ControlNet models and initialize the ControlNet extensions.
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
for controlnet, controlnet_cond, controlnet_model in zip(
controlnets, controlnet_conds, controlnet_models, strict=True
):
model = exit_stack.enter_context(controlnet_model)
if isinstance(model, XLabsControlNetFlux):
controlnet_extensions.append(
XLabsControlNetExtension(
model=model,
controlnet_cond=controlnet_cond,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
elif isinstance(model, InstantXControlNetFlux):
instantx_control_mode: torch.Tensor | None = None
if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
instantx_control_mode = instantx_control_mode.reshape([-1, 1])
controlnet_extensions.append(
InstantXControlNetExtension(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode,
weight=controlnet.control_weight,
begin_step_percent=controlnet.begin_step_percent,
end_step_percent=controlnet.end_step_percent,
)
)
else:
raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
return controlnet_extensions
def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
for lora in self.transformer.loras:
lora_info = context.models.load(lora.lora)

View File

@@ -0,0 +1,58 @@
from dataclasses import dataclass
import torch
@dataclass
class ControlNetFluxOutput:
single_block_residuals: list[torch.Tensor] | None
double_block_residuals: list[torch.Tensor] | None
def apply_weight(self, weight: float):
if self.single_block_residuals is not None:
for i in range(len(self.single_block_residuals)):
self.single_block_residuals[i] = self.single_block_residuals[i] * weight
if self.double_block_residuals is not None:
for i in range(len(self.double_block_residuals)):
self.double_block_residuals[i] = self.double_block_residuals[i] * weight
def add_tensor_lists_elementwise(
list1: list[torch.Tensor] | None, list2: list[torch.Tensor] | None
) -> list[torch.Tensor] | None:
"""Add two tensor lists elementwise that could be None."""
if list1 is None and list2 is None:
return None
if list1 is None:
return list2
if list2 is None:
return list1
new_list: list[torch.Tensor] = []
for list1_tensor, list2_tensor in zip(list1, list2, strict=True):
new_list.append(list1_tensor + list2_tensor)
return new_list
def add_controlnet_flux_outputs(
controlnet_output_1: ControlNetFluxOutput, controlnet_output_2: ControlNetFluxOutput
) -> ControlNetFluxOutput:
return ControlNetFluxOutput(
single_block_residuals=add_tensor_lists_elementwise(
controlnet_output_1.single_block_residuals, controlnet_output_2.single_block_residuals
),
double_block_residuals=add_tensor_lists_elementwise(
controlnet_output_1.double_block_residuals, controlnet_output_2.double_block_residuals
),
)
def sum_controlnet_flux_outputs(
controlnet_outputs: list[ControlNetFluxOutput],
) -> ControlNetFluxOutput:
controlnet_output_sum = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
for controlnet_output in controlnet_outputs:
controlnet_output_sum = add_controlnet_flux_outputs(controlnet_output_sum, controlnet_output)
return controlnet_output_sum
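As a rough illustration of how these helpers compose (shapes and values below are made up), two per-ControlNet outputs can be weighted in place and then summed:

```python
import torch

# Two hypothetical ControlNet outputs with 19 double-block residuals each.
out_a = ControlNetFluxOutput(
    single_block_residuals=None,
    double_block_residuals=[torch.ones(1, 4, 8) for _ in range(19)],
)
out_b = ControlNetFluxOutput(
    single_block_residuals=None,
    double_block_residuals=[torch.ones(1, 4, 8) for _ in range(19)],
)

out_a.apply_weight(0.5)  # scale ControlNet A's residuals in place
out_b.apply_weight(1.0)

merged = sum_controlnet_flux_outputs([out_a, out_b])
# Each merged residual is 0.5 * 1 + 1.0 * 1 = 1.5.
assert torch.allclose(merged.double_block_residuals[0], torch.full((1, 4, 8), 1.5))
```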

View File

@@ -0,0 +1,180 @@
# This file was initially copied from:
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py
from dataclasses import dataclass
import torch
import torch.nn as nn
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import (
DoubleStreamBlock,
EmbedND,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class InstantXControlNetFluxOutput:
controlnet_block_samples: list[torch.Tensor] | None
controlnet_single_block_samples: list[torch.Tensor] | None
# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL
# - in_channels: in_channels
# - num_layers: depth
# - num_single_layers: depth_single_blocks
# - attention_head_dim: hidden_size // num_heads
# - num_attention_heads: num_heads
# - joint_attention_dim: context_in_dim
# - pooled_projection_dim: vec_in_dim
# - guidance_embeds: guidance_embed
# - axes_dims_rope: axes_dim
class InstantXControlNetFlux(torch.nn.Module):
def __init__(self, params: FluxParams, num_control_modes: int | None = None):
"""
Args:
params (FluxParams): The parameters for the FLUX model.
num_control_modes (int | None, optional): The number of controlnet modes. If non-None, then the model is a
'union controlnet' model and expects a mode conditioning input at runtime.
"""
super().__init__()
# The following modules mirror the base FLUX transformer model.
# -------------------------------------------------------------
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
for _ in range(params.depth_single_blocks)
]
)
# The following modules are specific to the ControlNet model.
# -----------------------------------------------------------
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.double_blocks)):
self.controlnet_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
self.controlnet_single_blocks = nn.ModuleList([])
for _ in range(len(self.single_blocks)):
self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))
self.is_union = False
if num_control_modes is not None:
self.is_union = True
self.controlnet_mode_embedder = nn.Embedding(num_control_modes, self.hidden_size)
self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size))
def forward(
self,
controlnet_cond: torch.Tensor,
controlnet_mode: torch.Tensor | None,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> InstantXControlNetFluxOutput:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
# Add controlnet_cond embedding.
img = img + self.controlnet_x_embedder(controlnet_cond)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
# If this is a union ControlNet, then concat the control mode embedding to the T5 text embedding.
if self.is_union:
if controlnet_mode is None:
# We allow users to enter 'None' as the controlnet_mode if they don't want to worry about this input.
# We've chosen to use a zero-embedding in this case.
zero_index = torch.zeros([1, 1], dtype=torch.long, device=txt.device)
controlnet_mode_emb = torch.zeros_like(self.controlnet_mode_embedder(zero_index))
else:
controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
txt = torch.cat([controlnet_mode_emb, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
else:
assert controlnet_mode is None
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
double_block_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
double_block_samples.append(img)
img = torch.cat((txt, img), 1)
single_block_samples: list[torch.Tensor] = []
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
single_block_samples.append(img[:, txt.shape[1] :])
# ControlNet Block
controlnet_double_block_samples: list[torch.Tensor] = []
for double_block_sample, controlnet_block in zip(double_block_samples, self.controlnet_blocks, strict=True):
double_block_sample = controlnet_block(double_block_sample)
controlnet_double_block_samples.append(double_block_sample)
controlnet_single_block_samples: list[torch.Tensor] = []
for single_block_sample, controlnet_block in zip(
single_block_samples, self.controlnet_single_blocks, strict=True
):
single_block_sample = controlnet_block(single_block_sample)
controlnet_single_block_samples.append(single_block_sample)
return InstantXControlNetFluxOutput(
controlnet_block_samples=controlnet_double_block_samples or None,
controlnet_single_block_samples=controlnet_single_block_samples or None,
)
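The parameter-name note near the top of this file can be read as a direct translation table. A hedged sketch of applying it to a diffusers `config.json` dict follows; the helper is hypothetical, and the `mlp_ratio`, `theta`, and `qkv_bias` defaults are assumptions, since they are not part of the mapping:

```python
# Hypothetical helper (not part of this PR) showing the diffusers -> BFL
# parameter mapping described in the note above.
def flux_params_from_diffusers_config(cfg: dict) -> FluxParams:
    num_heads = cfg["num_attention_heads"]
    head_dim = cfg["attention_head_dim"]
    return FluxParams(
        in_channels=cfg["in_channels"],
        vec_in_dim=cfg["pooled_projection_dim"],
        context_in_dim=cfg["joint_attention_dim"],
        hidden_size=num_heads * head_dim,
        mlp_ratio=4.0,  # assumed; not covered by the mapping
        num_heads=num_heads,
        depth=cfg["num_layers"],
        depth_single_blocks=cfg["num_single_layers"],
        axes_dim=cfg["axes_dims_rope"],
        theta=10_000,  # assumed; correct for dev/schnell
        qkv_bias=True,  # assumed
        guidance_embed=cfg["guidance_embeds"],
    )
```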

View File

@@ -0,0 +1,295 @@
from typing import Any, Dict
import torch
from invokeai.backend.flux.model import FluxParams
def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an XLabs ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an XLabs ControlNet model.
expected_keys = {
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"input_hint_block.0.bias",
"input_hint_block.0.weight",
"pos_embed_input.bias",
"pos_embed_input.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False
def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
"""Is the state dict for an InstantX ControlNet model?
This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
"""
# If all of the expected keys are present, then this is very likely an InstantX ControlNet model.
expected_keys = {
"controlnet_blocks.0.bias",
"controlnet_blocks.0.weight",
"controlnet_x_embedder.bias",
"controlnet_x_embedder.weight",
}
if expected_keys.issubset(sd.keys()):
return True
return False
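A minimal usage sketch for the two detectors above, assuming a local safetensors file path:

```python
from safetensors.torch import load_file

sd = load_file("/path/to/flux_controlnet.safetensors")  # assumed path
if is_state_dict_xlabs_controlnet(sd):
    print("Detected an XLabs FLUX ControlNet state dict.")
elif is_state_dict_instantx_controlnet(sd):
    print("Detected an InstantX FLUX ControlNet state dict.")
else:
    print("Not recognized as a FLUX ControlNet state dict.")
```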
def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
"""Fuse weights along dimension 0.
Used to fuse q, k, v attention weights into a single qkv tensor when converting from diffusers to BFL format.
"""
# TODO(ryand): Double check dim=0 is correct.
return torch.cat(t, dim=0)
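For intuition, fusing separate q/k/v projections stacks them along dim 0; the sizes below are illustrative FLUX-dev-like values, not taken from the diff:

```python
import torch

q = torch.zeros(3072, 3072)
k = torch.zeros(3072, 3072)
v = torch.zeros(3072, 3072)
qkv = _fuse_weights(q, k, v)
assert qkv.shape == (3 * 3072, 3072)  # fused qkv weight in BFL layout
```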
def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
sd: Dict[str, torch.Tensor], double_block_index: int
) -> Dict[str, torch.Tensor]:
"""Convert the state dict for a double block from diffusers format to BFL format."""
to_prefix = f"double_blocks.{double_block_index}"
from_prefix = f"transformer_blocks.{double_block_index}"
new_sd: dict[str, torch.Tensor] = {}
# Check one key to determine if this block exists.
if f"{from_prefix}.attn.add_q_proj.bias" not in sd:
return new_sd
# txt_attn.qkv
new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.add_q_proj.bias"),
sd.pop(f"{from_prefix}.attn.add_k_proj.bias"),
sd.pop(f"{from_prefix}.attn.add_v_proj.bias"),
)
new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.add_q_proj.weight"),
sd.pop(f"{from_prefix}.attn.add_k_proj.weight"),
sd.pop(f"{from_prefix}.attn.add_v_proj.weight"),
)
# img_attn.qkv
new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.bias"),
sd.pop(f"{from_prefix}.attn.to_k.bias"),
sd.pop(f"{from_prefix}.attn.to_v.bias"),
)
new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.weight"),
sd.pop(f"{from_prefix}.attn.to_k.weight"),
sd.pop(f"{from_prefix}.attn.to_v.weight"),
)
# Handle basic 1-to-1 key conversions.
key_map = {
# img_attn
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
# img_mlp
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
# img_mod
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
# txt_attn
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
# txt_mlp
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
# txt_mod
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
}
for from_key, to_key in key_map.items():
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
return new_sd
def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
sd: Dict[str, torch.Tensor], single_block_index: int
) -> Dict[str, torch.Tensor]:
"""Convert the state dict for a single block from diffusers format to BFL format."""
to_prefix = f"single_blocks.{single_block_index}"
from_prefix = f"single_transformer_blocks.{single_block_index}"
new_sd: dict[str, torch.Tensor] = {}
# Check one key to determine if this block exists.
if f"{from_prefix}.attn.to_q.bias" not in sd:
return new_sd
# linear1 (qkv)
new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.bias"),
sd.pop(f"{from_prefix}.attn.to_k.bias"),
sd.pop(f"{from_prefix}.attn.to_v.bias"),
sd.pop(f"{from_prefix}.proj_mlp.bias"),
)
new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.to_q.weight"),
sd.pop(f"{from_prefix}.attn.to_k.weight"),
sd.pop(f"{from_prefix}.attn.to_v.weight"),
sd.pop(f"{from_prefix}.proj_mlp.weight"),
)
# Handle basic 1-to-1 key conversions.
key_map = {
# linear2
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
# modulation
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
# norm
"attn.norm_k.weight": "norm.key_norm.scale",
"attn.norm_q.weight": "norm.query_norm.scale",
}
for from_key, to_key in key_map.items():
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
return new_sd
def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Convert an InstantX ControlNet state dict to the format that can be loaded by our internal
InstantXControlNetFlux model.
The original InstantX ControlNet model was developed to be used in diffusers. We have ported the original
implementation to InstantXControlNetFlux to make it compatible with BFL-style models. This function converts the
original state dict to the format expected by InstantXControlNetFlux.
"""
# Shallow copy sd so that we can pop keys from it without modifying the original.
sd = sd.copy()
new_sd: dict[str, torch.Tensor] = {}
# Handle basic 1-to-1 key conversions.
basic_key_map = {
# Base model keys.
# ----------------
# txt_in keys.
"context_embedder.bias": "txt_in.bias",
"context_embedder.weight": "txt_in.weight",
# guidance_in MLPEmbedder keys.
"time_text_embed.guidance_embedder.linear_1.bias": "guidance_in.in_layer.bias",
"time_text_embed.guidance_embedder.linear_1.weight": "guidance_in.in_layer.weight",
"time_text_embed.guidance_embedder.linear_2.bias": "guidance_in.out_layer.bias",
"time_text_embed.guidance_embedder.linear_2.weight": "guidance_in.out_layer.weight",
# vector_in MLPEmbedder keys.
"time_text_embed.text_embedder.linear_1.bias": "vector_in.in_layer.bias",
"time_text_embed.text_embedder.linear_1.weight": "vector_in.in_layer.weight",
"time_text_embed.text_embedder.linear_2.bias": "vector_in.out_layer.bias",
"time_text_embed.text_embedder.linear_2.weight": "vector_in.out_layer.weight",
# time_in MLPEmbedder keys.
"time_text_embed.timestep_embedder.linear_1.bias": "time_in.in_layer.bias",
"time_text_embed.timestep_embedder.linear_1.weight": "time_in.in_layer.weight",
"time_text_embed.timestep_embedder.linear_2.bias": "time_in.out_layer.bias",
"time_text_embed.timestep_embedder.linear_2.weight": "time_in.out_layer.weight",
# img_in keys.
"x_embedder.bias": "img_in.bias",
"x_embedder.weight": "img_in.weight",
}
for old_key, new_key in basic_key_map.items():
v = sd.pop(old_key, None)
if v is not None:
new_sd[new_key] = v
# Handle the double_blocks.
block_index = 0
while True:
converted_double_block_sd = _convert_flux_double_block_sd_from_diffusers_to_bfl_format(sd, block_index)
if len(converted_double_block_sd) == 0:
break
new_sd.update(converted_double_block_sd)
block_index += 1
# Handle the single_blocks.
block_index = 0
while True:
converted_single_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index)
if len(converted_single_block_sd) == 0:
break
new_sd.update(converted_single_block_sd)
block_index += 1
# Transfer controlnet keys as-is.
for k in list(sd.keys()):
if k.startswith("controlnet_"):
new_sd[k] = sd.pop(k)
# Assert that all keys have been handled.
assert len(sd) == 0
return new_sd
def infer_flux_params_from_state_dict(sd: Dict[str, torch.Tensor]) -> FluxParams:
"""Infer the FluxParams from the shape of a FLUX state dict. When a model is distributed in diffusers format, this
information is all contained in the config.json file that accompanies the model. However, being able to infer the
params from the state dict enables us to load models (e.g. an InstantX ControlNet) from a single weight file.
"""
hidden_size = sd["img_in.weight"].shape[0]
mlp_hidden_dim = sd["double_blocks.0.img_mlp.0.weight"].shape[0]
# mlp_ratio is a float, but we treat it as an int here to avoid having to think about possible float precision
# issues. In practice, mlp_ratio is usually 4.
mlp_ratio = mlp_hidden_dim // hidden_size
head_dim = sd["double_blocks.0.img_attn.norm.query_norm.scale"].shape[0]
num_heads = hidden_size // head_dim
# Count the number of double blocks.
double_block_index = 0
while f"double_blocks.{double_block_index}.img_attn.qkv.weight" in sd:
double_block_index += 1
# Count the number of single blocks.
single_block_index = 0
while f"single_blocks.{single_block_index}.linear1.weight" in sd:
single_block_index += 1
return FluxParams(
in_channels=sd["img_in.weight"].shape[1],
vec_in_dim=sd["vector_in.in_layer.weight"].shape[1],
context_in_dim=sd["txt_in.weight"].shape[1],
hidden_size=hidden_size,
mlp_ratio=mlp_ratio,
num_heads=num_heads,
depth=double_block_index,
depth_single_blocks=single_block_index,
# axes_dim cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
axes_dim=[16, 56, 56],
# theta cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
theta=10_000,
qkv_bias="double_blocks.0.img_attn.qkv.bias" in sd,
guidance_embed="guidance_in.in_layer.weight" in sd,
)
def infer_instantx_num_control_modes_from_state_dict(sd: Dict[str, torch.Tensor]) -> int | None:
"""Infer the number of ControlNet Union modes from the shape of a InstantX ControlNet state dict.
Returns None if the model is not a ControlNet Union model. Otherwise returns the number of modes.
"""
mode_embedder_key = "controlnet_mode_embedder.weight"
if mode_embedder_key not in sd:
return None
return sd[mode_embedder_key].shape[0]

View File

@@ -0,0 +1,130 @@
# This file was initially based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py
from dataclasses import dataclass
import torch
from einops import rearrange
from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding
@dataclass
class XLabsControlNetFluxOutput:
controlnet_double_block_residuals: list[torch.Tensor] | None
class XLabsControlNetFlux(torch.nn.Module):
"""A ControlNet model for FLUX.
The architecture is very similar to the base FLUX model, with the following differences:
- A `controlnet_depth` parameter is passed to control the number of double_blocks that the ControlNet is applied to.
In order to keep the ControlNet small, this is typically much less than the depth of the base FLUX model.
- There is a set of `controlnet_blocks` that are applied to the output of each double_block.
"""
def __init__(self, params: FluxParams, controlnet_depth: int = 2):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else torch.nn.Identity()
)
self.txt_in = torch.nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = torch.nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(controlnet_depth)
]
)
# Add ControlNet blocks.
self.controlnet_blocks = torch.nn.ModuleList([])
for _ in range(controlnet_depth):
controlnet_block = torch.nn.Linear(self.hidden_size, self.hidden_size)
controlnet_block = zero_module(controlnet_block)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.input_hint_block = torch.nn.Sequential(
torch.nn.Conv2d(3, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1),
torch.nn.SiLU(),
torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
torch.nn.SiLU(),
zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)),
)
def forward(
self,
img: torch.Tensor,
img_ids: torch.Tensor,
controlnet_cond: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
timesteps: torch.Tensor,
y: torch.Tensor,
guidance: torch.Tensor | None = None,
) -> XLabsControlNetFluxOutput:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
controlnet_cond = self.input_hint_block(controlnet_cond)
controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
controlnet_cond = self.pos_embed_input(controlnet_cond)
img = img + controlnet_cond
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
block_res_samples: list[torch.Tensor] = []
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
block_res_samples.append(img)
controlnet_block_res_samples: list[torch.Tensor] = []
for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=True):
block_res_sample = controlnet_block(block_res_sample)
controlnet_block_res_samples.append(block_res_sample)
return XLabsControlNetFluxOutput(controlnet_double_block_residuals=controlnet_block_res_samples)

View File

@@ -0,0 +1,12 @@
from typing import TypeVar
import torch
T = TypeVar("T", bound=torch.nn.Module)
def zero_module(module: T) -> T:
"""Initialize the parameters of a module to zero."""
for p in module.parameters():
torch.nn.init.zeros_(p)
return module
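A quick sanity-check sketch of the effect: a zero-initialized projection outputs zeros, so ControlNet residual branches start as no-ops before their weights are loaded or trained:

```python
import torch

proj = zero_module(torch.nn.Linear(8, 8))
x = torch.randn(2, 8)
assert torch.all(proj(x) == 0)  # weight and bias are all zeros
```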

View File

@@ -3,7 +3,10 @@ from typing import Callable
import torch
from tqdm import tqdm
from invokeai.backend.flux.inpaint_extension import InpaintExtension
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
from invokeai.backend.flux.model import Flux
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
@@ -21,6 +24,7 @@ def denoise(
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
inpaint_extension: InpaintExtension | None,
controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
):
# step 0 is the initial state
total_steps = len(timesteps) - 1
@@ -38,6 +42,30 @@ def denoise(
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
# Run ControlNet models.
controlnet_residuals: list[ControlNetFluxOutput] = []
for controlnet_extension in controlnet_extensions:
controlnet_residuals.append(
controlnet_extension.run_controlnet(
timestep_index=step - 1,
total_num_timesteps=total_steps,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
)
)
# Merge the ControlNet residuals from multiple ControlNets.
# TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind that the
# controlnet_residuals data structure is efficient in that it likely contains multiple references to the same
# tensors. Calculating the sum materializes each tensor into its own instance.
merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
pred = model(
img=img,
img_ids=img_ids,
@@ -46,6 +74,8 @@ def denoise(
y=vec,
timesteps=t_vec,
guidance=guidance_vec,
controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
)
preview_img = img - t_curr * pred

View File

@@ -0,0 +1,45 @@
import math
from abc import ABC, abstractmethod
from typing import List, Union
import torch
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
class BaseControlNetExtension(ABC):
def __init__(
self,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
self._weight = weight
self._begin_step_percent = begin_step_percent
self._end_step_percent = end_step_percent
def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
first_step = math.floor(self._begin_step_percent * total_num_timesteps)
last_step = math.ceil(self._end_step_percent * total_num_timesteps)
if timestep_index < first_step or timestep_index > last_step:
return 0.0
if isinstance(self._weight, list):
return self._weight[timestep_index]
return self._weight
@abstractmethod
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput: ...
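A worked example of the step-range gating in `_get_weight`, with assumed values:

```python
import math

begin_step_percent, end_step_percent, total_num_timesteps = 0.25, 0.75, 20
first_step = math.floor(begin_step_percent * total_num_timesteps)  # 5
last_step = math.ceil(end_step_percent * total_num_timesteps)      # 15
# The ControlNet contributes its weight only for timestep indices 5..15
# (inclusive); outside that window _get_weight returns 0.0 and the ControlNet
# run is skipped by the extensions.
```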

View File

@@ -0,0 +1,194 @@
import math
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
InstantXControlNetFlux,
InstantXControlNetFluxOutput,
)
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.model_manager.load.load_base import LoadedModel
class InstantXControlNetExtension(BaseControlNetExtension):
def __init__(
self,
model: InstantXControlNetFlux,
controlnet_cond: torch.Tensor,
instantx_control_mode: torch.Tensor | None,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
super().__init__(
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
self._model = model
# The VAE-encoded and 'packed' control image to pass to the ControlNet model.
self._controlnet_cond = controlnet_cond
# TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
# The control mode for InstantX ControlNet union models.
# See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
# Expected shape: (batch_size, 1), Expected dtype: torch.long
# If None, a zero-embedding will be used.
self._instantx_control_mode = instantx_control_mode
# TODO(ryand): Pass in these params if a new base transformer / InstantX ControlNet pair get released.
self._flux_transformer_num_double_blocks = 19
self._flux_transformer_num_single_blocks = 38
@classmethod
def prepare_controlnet_cond(
cls,
controlnet_image: Image,
vae_info: LoadedModel,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
resized_controlnet_image = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Shift the image from [0, 1] to [-1, 1].
resized_controlnet_image = resized_controlnet_image * 2 - 1
# Run VAE encoder.
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
controlnet_cond = pack(controlnet_cond)
return controlnet_cond
@classmethod
def from_controlnet_image(
cls,
model: InstantXControlNetFlux,
controlnet_image: Image,
instantx_control_mode: torch.Tensor | None,
vae_info: LoadedModel,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
resized_controlnet_image = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Shift the image from [0, 1] to [-1, 1].
resized_controlnet_image = resized_controlnet_image * 2 - 1
# Run VAE encoder.
controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
controlnet_cond = pack(controlnet_cond)
return cls(
model=model,
controlnet_cond=controlnet_cond,
instantx_control_mode=instantx_control_mode,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
def _instantx_output_to_controlnet_output(
self, instantx_output: InstantXControlNetFluxOutput
) -> ControlNetFluxOutput:
# The `interval_control` logic here is based on
# https://github.com/huggingface/diffusers/blob/31058cdaef63ca660a1a045281d156239fba8192/src/diffusers/models/transformers/transformer_flux.py#L507-L511
# Handle double block residuals.
double_block_residuals: list[torch.Tensor] = []
double_block_samples = instantx_output.controlnet_block_samples
if double_block_samples:
interval_control = self._flux_transformer_num_double_blocks / len(double_block_samples)
interval_control = int(math.ceil(interval_control))
for i in range(self._flux_transformer_num_double_blocks):
double_block_residuals.append(double_block_samples[i // interval_control])
# Handle single block residuals.
single_block_residuals: list[torch.Tensor] = []
single_block_samples = instantx_output.controlnet_single_block_samples
if single_block_samples:
interval_control = self._flux_transformer_num_single_blocks / len(single_block_samples)
interval_control = int(math.ceil(interval_control))
for i in range(self._flux_transformer_num_single_blocks):
single_block_residuals.append(single_block_samples[i // interval_control])
return ControlNetFluxOutput(
double_block_residuals=double_block_residuals or None,
single_block_residuals=single_block_residuals or None,
)
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
# Make sure inputs have correct device and dtype.
self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
self._instantx_control_mode = (
self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
)
instantx_output: InstantXControlNetFluxOutput = self._model(
controlnet_cond=self._controlnet_cond,
controlnet_mode=self._instantx_control_mode,
img=img,
img_ids=img_ids,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
y=y,
guidance=guidance,
)
controlnet_output = self._instantx_output_to_controlnet_output(instantx_output)
controlnet_output.apply_weight(weight)
return controlnet_output
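A worked example of the `interval_control` expansion in `_instantx_output_to_controlnet_output`, with assumed sizes:

```python
import math

num_transformer_double_blocks = 19  # base FLUX transformer
num_controlnet_samples = 5          # assumed InstantX output length
interval_control = math.ceil(num_transformer_double_blocks / num_controlnet_samples)  # 4
mapping = [i // interval_control for i in range(num_transformer_double_blocks)]
# mapping == [0, 0, 0, 0, 1, 1, 1, 1, 2, ...]: sample 0 is reused for
# transformer blocks 0-3, sample 1 for blocks 4-7, and so on.
```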

View File

@@ -0,0 +1,150 @@
from typing import List, Union
import torch
from PIL.Image import Image
from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
class XLabsControlNetExtension(BaseControlNetExtension):
def __init__(
self,
model: XLabsControlNetFlux,
controlnet_cond: torch.Tensor,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
super().__init__(
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
self._model = model
# _controlnet_cond is the control image passed to the ControlNet model.
# Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
self._controlnet_cond = controlnet_cond
# TODO(ryand): Pass in these params if a new base transformer / XLabs ControlNet pair get released.
self._flux_transformer_num_double_blocks = 19
self._flux_transformer_num_single_blocks = 38
@classmethod
def prepare_controlnet_cond(
cls,
controlnet_image: Image,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
controlnet_cond = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Map pixel values from [0, 1] to [-1, 1].
controlnet_cond = controlnet_cond * 2 - 1
return controlnet_cond
@classmethod
def from_controlnet_image(
cls,
model: XLabsControlNetFlux,
controlnet_image: Image,
latent_height: int,
latent_width: int,
dtype: torch.dtype,
device: torch.device,
resize_mode: CONTROLNET_RESIZE_VALUES,
weight: Union[float, List[float]],
begin_step_percent: float,
end_step_percent: float,
):
image_height = latent_height * LATENT_SCALE_FACTOR
image_width = latent_width * LATENT_SCALE_FACTOR
controlnet_cond = prepare_control_image(
image=controlnet_image,
do_classifier_free_guidance=False,
width=image_width,
height=image_height,
device=device,
dtype=dtype,
control_mode="balanced",
resize_mode=resize_mode,
)
# Map pixel values from [0, 1] to [-1, 1].
controlnet_cond = controlnet_cond * 2 - 1
return cls(
model=model,
controlnet_cond=controlnet_cond,
weight=weight,
begin_step_percent=begin_step_percent,
end_step_percent=end_step_percent,
)
def _xlabs_output_to_controlnet_output(self, xlabs_output: XLabsControlNetFluxOutput) -> ControlNetFluxOutput:
# The modulo index logic used here is based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/model.py#L198-L200
# Handle double block residuals.
double_block_residuals: list[torch.Tensor] = []
xlabs_double_block_residuals = xlabs_output.controlnet_double_block_residuals
if xlabs_double_block_residuals is not None:
for i in range(self._flux_transformer_num_double_blocks):
double_block_residuals.append(xlabs_double_block_residuals[i % len(xlabs_double_block_residuals)])
return ControlNetFluxOutput(
double_block_residuals=double_block_residuals,
single_block_residuals=None,
)
def run_controlnet(
self,
timestep_index: int,
total_num_timesteps: int,
img: torch.Tensor,
img_ids: torch.Tensor,
txt: torch.Tensor,
txt_ids: torch.Tensor,
y: torch.Tensor,
timesteps: torch.Tensor,
guidance: torch.Tensor | None,
) -> ControlNetFluxOutput:
weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
if weight < 1e-6:
return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)
xlabs_output: XLabsControlNetFluxOutput = self._model(
img=img,
img_ids=img_ids,
controlnet_cond=self._controlnet_cond,
txt=txt,
txt_ids=txt_ids,
timesteps=timesteps,
y=y,
guidance=guidance,
)
controlnet_output = self._xlabs_output_to_controlnet_output(xlabs_output)
controlnet_output.apply_weight(weight)
return controlnet_output
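A worked example of the modulo indexing in `_xlabs_output_to_controlnet_output`, with assumed sizes:

```python
num_transformer_double_blocks = 19  # base FLUX transformer
num_xlabs_residuals = 2             # assumed: XLabs ControlNet with controlnet_depth=2
mapping = [i % num_xlabs_residuals for i in range(num_transformer_double_blocks)]
# mapping == [0, 1, 0, 1, ...]: residual 0 is applied to even-indexed double
# blocks and residual 1 to odd-indexed double blocks.
```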

View File

@@ -87,7 +87,9 @@ class Flux(nn.Module):
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor | None = None,
guidance: Tensor | None,
controlnet_double_block_residuals: list[Tensor] | None,
controlnet_single_block_residuals: list[Tensor] | None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -105,12 +107,27 @@ class Flux(nn.Module):
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
# Validate double_block_residuals shape.
if controlnet_double_block_residuals is not None:
assert len(controlnet_double_block_residuals) == len(self.double_blocks)
for block_index, block in enumerate(self.double_blocks):
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
if controlnet_double_block_residuals is not None:
img += controlnet_double_block_residuals[block_index]
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
# Validate single_block_residuals shape.
if controlnet_single_block_residuals is not None:
assert len(controlnet_single_block_residuals) == len(self.single_blocks)
for block_index, block in enumerate(self.single_blocks):
img = block(img, vec=vec, pe=pe)
if controlnet_single_block_residuals is not None:
img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)

View File

@@ -8,17 +8,36 @@ from diffusers import ControlNetModel
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
)
from invokeai.backend.model_manager.config import (
BaseModelType,
ControlNetCheckpointConfig,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Diffusers
)
@ModelLoaderRegistry.register(
base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
)
class ControlNetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""

View File

@@ -10,6 +10,15 @@ from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer
from invokeai.app.services.config.config_default import get_config
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
convert_diffusers_instantx_state_dict_to_bfl_format,
infer_flux_params_from_state_dict,
infer_instantx_num_control_modes_from_state_dict,
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.modules.autoencoder import AutoEncoder
from invokeai.backend.flux.util import ae_params, params
@@ -24,6 +33,8 @@ from invokeai.backend.model_manager import (
from invokeai.backend.model_manager.config import (
CheckpointConfigBase,
CLIPEmbedDiffusersConfig,
ControlNetCheckpointConfig,
ControlNetDiffusersConfig,
MainBnbQuantized4bCheckpointConfig,
MainCheckpointConfig,
MainGGUFCheckpointConfig,
@@ -293,3 +304,51 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
sd = convert_bundle_to_flux_transformer_checkpoint(sd)
model.load_state_dict(sd, assign=True)
return model
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
class FluxControlnetModel(ModelLoader):
"""Class to load FLUX ControlNet models."""
def _load_model(
self,
config: AnyModelConfig,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if isinstance(config, ControlNetCheckpointConfig):
model_path = Path(config.path)
elif isinstance(config, ControlNetDiffusersConfig):
# If this is a diffusers directory, we simply ignore the config file and load from the weight file.
model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
else:
raise ValueError(f"Unexpected ControlNet model config type: {type(config)}")
sd = load_file(model_path)
# Detect the FLUX ControlNet model type from the state dict.
if is_state_dict_xlabs_controlnet(sd):
return self._load_xlabs_controlnet(sd)
elif is_state_dict_instantx_controlnet(sd):
return self._load_instantx_controlnet(sd)
else:
raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.")
def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
with accelerate.init_empty_weights():
# HACK(ryand): Is it safe to assume dev here?
model = XLabsControlNetFlux(params["flux-dev"])
model.load_state_dict(sd, assign=True)
return model
def _load_instantx_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
flux_params = infer_flux_params_from_state_dict(sd)
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
with accelerate.init_empty_weights():
model = InstantXControlNetFlux(flux_params, num_control_modes)
model.load_state_dict(sd, assign=True)
return model
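
Note: the loader above dispatches on heuristics over the raw state dict rather than on config metadata. The actual `is_state_dict_xlabs_controlnet()` / `is_state_dict_instantx_controlnet()` implementations are not included in this diff; the sketch below (with hypothetical helper names) only illustrates the kind of key-prefix checks they could perform, based on the fixture state dicts added later in this PR.

```python
import torch

# Hypothetical sketch only -- the real detection helpers live in
# invokeai.backend.flux.controlnet.state_dict_utils and are not shown in this diff.
def looks_like_xlabs_controlnet(sd: dict[str, torch.Tensor]) -> bool:
    # XLabs checkpoints use BFL-style keys ("double_blocks.*") plus a conv "input_hint_block".
    return any(k.startswith("input_hint_block.") for k in sd) and any(
        k.startswith("double_blocks.") for k in sd
    )

def looks_like_instantx_controlnet(sd: dict[str, torch.Tensor]) -> bool:
    # InstantX checkpoints are in diffusers format with a dedicated controlnet_x_embedder.
    return any(k.startswith("controlnet_x_embedder.") for k in sd) and any(
        k.startswith("transformer_blocks.") for k in sd
    )
```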

View File

@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger
from invokeai.app.util.misc import uuid_string
from invokeai.backend.flux.controlnet.state_dict_utils import (
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
is_state_dict_likely_in_flux_diffusers_format,
)
@@ -116,6 +120,7 @@ class ModelProbe(object):
"CLIPModel": ModelType.CLIPEmbed,
"CLIPTextModel": ModelType.CLIPEmbed,
"T5EncoderModel": ModelType.T5Encoder,
"FluxControlNetModel": ModelType.ControlNet,
}
@classmethod
@@ -255,7 +260,19 @@ class ModelProbe(object):
# LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
return ModelType.LoRA
elif key.startswith(("controlnet", "control_model", "input_blocks")):
elif key.startswith(
(
"controlnet",
"control_model",
"input_blocks",
# XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
# For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
# TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
# "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
# delicate.
"controlnet_blocks",
)
):
return ModelType.ControlNet
elif key.startswith(("image_proj.", "ip_adapter.")):
return ModelType.IPAdapter
@@ -438,6 +455,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
"lineart": "lineart_image_processor",
"lineart_anime": "lineart_anime_image_processor",
"softedge": "hed_image_processor",
"hed": "hed_image_processor",
"shuffle": "content_shuffle_image_processor",
"pose": "dw_openpose_image_processor",
"mediapipe": "mediapipe_face_processor",
@@ -449,7 +467,8 @@ MODEL_NAME_TO_PREPROCESSOR = {
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
if k in model_name:
model_name_lower = model_name.lower()
if k in model_name_lower:
return ControlAdapterDefaultSettings(preprocessor=v)
return None
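
The change above makes the default-preprocessor lookup case-insensitive, so mixed-case model names still match the lowercase keys in `MODEL_NAME_TO_PREPROCESSOR`. A minimal illustration (the model name is made up; "hed" is one of the mapping keys added in this diff):

```python
MODEL_NAME_TO_PREPROCESSOR = {"hed": "hed_image_processor"}  # excerpt of the real mapping

def default_preprocessor(model_name: str) -> str | None:
    model_name_lower = model_name.lower()
    for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
        if k in model_name_lower:
            return v
    return None

# An uppercase "HED" in the model name no longer misses the lowercase "hed" key.
assert default_preprocessor("FLUX-HED-ControlNet") == "hed_image_processor"
```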
@@ -623,6 +642,11 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
checkpoint = self.checkpoint
if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint):
# TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
# get_format()?
return BaseModelType.Flux
for key_name in (
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"controlnet_mid_block.bias",
@@ -844,22 +868,19 @@ class ControlNetFolderProbe(FolderProbeBase):
raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
with open(config_file, "r") as file:
config = json.load(file)
if config.get("_class_name", None) == "FluxControlNetModel":
return BaseModelType.Flux
# no obvious way to distinguish between sd2-base and sd2-768
dimension = config["cross_attention_dim"]
base_model = (
BaseModelType.StableDiffusion1
if dimension == 768
else (
BaseModelType.StableDiffusion2
if dimension == 1024
else BaseModelType.StableDiffusionXL
if dimension == 2048
else None
)
)
if not base_model:
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
return base_model
if dimension == 768:
return BaseModelType.StableDiffusion1
if dimension == 1024:
return BaseModelType.StableDiffusion2
if dimension == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
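
For diffusers-format folders, the probe now short-circuits on the config's `_class_name` before falling back to the `cross_attention_dim` heuristic. A small sketch of that decision (illustrative string labels rather than `BaseModelType` members; the FLUX case matches the InstantX config fixture added later in this PR):

```python
import json

def base_from_config(config: dict) -> str:
    # Mirrors the probe logic above, returning plain strings for illustration.
    if config.get("_class_name", None) == "FluxControlNetModel":
        return "flux"
    dimension_to_base = {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}
    return dimension_to_base[config["cross_attention_dim"]]

assert base_from_config(json.loads('{"_class_name": "FluxControlNetModel"}')) == "flux"
assert base_from_config({"cross_attention_dim": 2048}) == "sdxl"
```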
class LoRAFolderProbe(FolderProbeBase):

View File

@@ -422,6 +422,13 @@ STARTER_MODELS: list[StarterModel] = [
description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
type=ModelType.ControlNet,
),
StarterModel(
name="FLUX.1-dev-Controlnet-Union-Pro",
base=BaseModelType.Flux,
source="Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
description="A unified ControlNet for FLUX.1-dev model that supports 7 control modes, including canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6)",
type=ModelType.ControlNet,
),
# endregion
# region T2I Adapter
StarterModel(

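The union starter model's description lists the seven supported control modes and their indices. The dict below is a hypothetical convenience mapping (not part of the PR) for translating those mode names into the corresponding integer:

```python
# Hypothetical helper, not part of the PR: union mode names -> integer index,
# taken from the starter-model description above.
UNION_CONTROL_MODES = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "low quality": 6,
}

assert UNION_CONTROL_MODES["depth"] == 2
```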
View File

@@ -80,7 +80,6 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addControlLayer}
isDisabled={isFLUX}
>
{t('controlLayers.controlLayer')}
</Button>

View File

@@ -56,7 +56,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
</MenuItem>
</MenuGroup>
<MenuGroup title={t('controlLayers.layer_other')}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isFLUX}>
<MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
{t('controlLayers.controlLayer')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>

View File

@@ -16,6 +16,7 @@ import {
controlLayerModelChanged,
controlLayerWeightChanged,
} from 'features/controlLayers/store/canvasSlice';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
@@ -42,6 +43,7 @@ export const ControlLayerControlAdapter = memo(() => {
const entityIdentifier = useEntityIdentifierContext('control_layer');
const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
const filter = useEntityFilter(entityIdentifier);
const isFLUX = useAppSelector(selectIsFLUX);
const onChangeBeginEndStepPct = useCallback(
(beginEndStepPct: [number, number]) => {
@@ -117,7 +119,7 @@ export const ControlLayerControlAdapter = memo(() => {
</Flex>
<Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
<BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
{controlAdapter.type === 'controlnet' && (
{controlAdapter.type === 'controlnet' && !isFLUX && (
<ControlLayerControlAdapterControlMode
controlMode={controlAdapter.controlMode}
onChange={onChangeControlMode}

View File

@@ -110,10 +110,10 @@ const addControlNetToGraph = (
const controlNet = g.addNode({
id: `control_net_${id}`,
type: 'controlnet',
type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet',
begin_step_percent: beginEndStepPct[0],
end_step_percent: beginEndStepPct[1],
control_mode: controlMode,
control_mode: model.base === 'flux' ? undefined : controlMode,
resize_mode: 'just_resize',
control_model: model,
control_weight: weight,

View File

@@ -19,6 +19,8 @@ import type { Invocation } from 'services/api/types';
import { isNonRefinerMainModelConfig } from 'services/api/types';
import { assert } from 'tsafe';
import { addControlNets } from './addControlAdapters';
const log = logger('system');
export const buildFLUXGraph = async (
@@ -93,6 +95,7 @@ export const buildFLUXGraph = async (
> = l2i;
g.addEdge(modelLoader, 'transformer', noise, 'transformer');
g.addEdge(modelLoader, 'vae', noise, 'controlnet_vae');
g.addEdge(modelLoader, 'vae', l2i, 'vae');
g.addEdge(modelLoader, 'clip', posCond, 'clip');
@@ -177,6 +180,24 @@ export const buildFLUXGraph = async (
);
}
const controlNetCollector = g.addNode({
type: 'collect',
id: getPrefixedId('control_net_collector'),
});
const controlNetResult = await addControlNets(
manager,
canvas.controlLayers.entities,
g,
canvas.bbox.rect,
controlNetCollector,
modelConfig.base
);
if (controlNetResult.addedControlNets > 0) {
g.addEdge(controlNetCollector, 'collection', noise, 'control');
} else {
g.deleteNode(controlNetCollector.id);
}
if (state.system.shouldUseNSFWChecker) {
canvasOutput = addNSFWChecker(g, canvasOutput);
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,30 @@
import argparse
import json
from safetensors.torch import load_file
def extract_sd_keys_and_shapes(safetensors_file: str):
sd = load_file(safetensors_file)
keys_to_shapes = {k: v.shape for k, v in sd.items()}
out_file = "keys_and_shapes.json"
with open(out_file, "w") as f:
json.dump(keys_to_shapes, f, indent=4)
print(f"Keys and shapes written to '{out_file}'.")
def main():
parser = argparse.ArgumentParser(
description="Extracts the keys and shapes from the state dict in a safetensors file. Intended for creating "
+ "dummy state dicts for use in unit tests."
)
parser.add_argument("safetensors_file", type=str, help="Path to the safetensors file.")
args = parser.parse_args()
extract_sd_keys_and_shapes(args.safetensors_file)
if __name__ == "__main__":
main()
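
Hypothetical usage of the helper script above (the script path and checkpoint filename are placeholders); the resulting JSON has the same structure as the fixture dictionaries added below:

```python
# Hypothetical invocation:
#   python extract_sd_keys_and_shapes.py flux-canny-controlnet_v2.safetensors
#
# keys_and_shapes.json then contains entries such as:
#   "controlnet_blocks.0.bias": [3072],
#   "controlnet_blocks.0.weight": [3072, 3072],
#   ...
```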

View File

@@ -0,0 +1,374 @@
# State dict keys and shapes for an InstantX FLUX ControlNet Union model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/4f32d6f2b220f8873d49bb8acc073e1df180c994/diffusion_pytorch_model.safetensors
instantx_sd_shapes = {
"context_embedder.bias": [3072],
"context_embedder.weight": [3072, 4096],
"controlnet_blocks.0.bias": [3072],
"controlnet_blocks.0.weight": [3072, 3072],
"controlnet_blocks.1.bias": [3072],
"controlnet_blocks.1.weight": [3072, 3072],
"controlnet_blocks.2.bias": [3072],
"controlnet_blocks.2.weight": [3072, 3072],
"controlnet_blocks.3.bias": [3072],
"controlnet_blocks.3.weight": [3072, 3072],
"controlnet_blocks.4.bias": [3072],
"controlnet_blocks.4.weight": [3072, 3072],
"controlnet_mode_embedder.weight": [10, 3072],
"controlnet_single_blocks.0.bias": [3072],
"controlnet_single_blocks.0.weight": [3072, 3072],
"controlnet_single_blocks.1.bias": [3072],
"controlnet_single_blocks.1.weight": [3072, 3072],
"controlnet_single_blocks.2.bias": [3072],
"controlnet_single_blocks.2.weight": [3072, 3072],
"controlnet_single_blocks.3.bias": [3072],
"controlnet_single_blocks.3.weight": [3072, 3072],
"controlnet_single_blocks.4.bias": [3072],
"controlnet_single_blocks.4.weight": [3072, 3072],
"controlnet_single_blocks.5.bias": [3072],
"controlnet_single_blocks.5.weight": [3072, 3072],
"controlnet_single_blocks.6.bias": [3072],
"controlnet_single_blocks.6.weight": [3072, 3072],
"controlnet_single_blocks.7.bias": [3072],
"controlnet_single_blocks.7.weight": [3072, 3072],
"controlnet_single_blocks.8.bias": [3072],
"controlnet_single_blocks.8.weight": [3072, 3072],
"controlnet_single_blocks.9.bias": [3072],
"controlnet_single_blocks.9.weight": [3072, 3072],
"controlnet_x_embedder.bias": [3072],
"controlnet_x_embedder.weight": [3072, 64],
"single_transformer_blocks.0.attn.norm_k.weight": [128],
"single_transformer_blocks.0.attn.norm_q.weight": [128],
"single_transformer_blocks.0.attn.to_k.bias": [3072],
"single_transformer_blocks.0.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.0.attn.to_q.bias": [3072],
"single_transformer_blocks.0.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.0.attn.to_v.bias": [3072],
"single_transformer_blocks.0.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.0.norm.linear.bias": [9216],
"single_transformer_blocks.0.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.0.proj_mlp.bias": [12288],
"single_transformer_blocks.0.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.0.proj_out.bias": [3072],
"single_transformer_blocks.0.proj_out.weight": [3072, 15360],
"single_transformer_blocks.1.attn.norm_k.weight": [128],
"single_transformer_blocks.1.attn.norm_q.weight": [128],
"single_transformer_blocks.1.attn.to_k.bias": [3072],
"single_transformer_blocks.1.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.1.attn.to_q.bias": [3072],
"single_transformer_blocks.1.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.1.attn.to_v.bias": [3072],
"single_transformer_blocks.1.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.1.norm.linear.bias": [9216],
"single_transformer_blocks.1.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.1.proj_mlp.bias": [12288],
"single_transformer_blocks.1.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.1.proj_out.bias": [3072],
"single_transformer_blocks.1.proj_out.weight": [3072, 15360],
"single_transformer_blocks.2.attn.norm_k.weight": [128],
"single_transformer_blocks.2.attn.norm_q.weight": [128],
"single_transformer_blocks.2.attn.to_k.bias": [3072],
"single_transformer_blocks.2.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.2.attn.to_q.bias": [3072],
"single_transformer_blocks.2.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.2.attn.to_v.bias": [3072],
"single_transformer_blocks.2.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.2.norm.linear.bias": [9216],
"single_transformer_blocks.2.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.2.proj_mlp.bias": [12288],
"single_transformer_blocks.2.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.2.proj_out.bias": [3072],
"single_transformer_blocks.2.proj_out.weight": [3072, 15360],
"single_transformer_blocks.3.attn.norm_k.weight": [128],
"single_transformer_blocks.3.attn.norm_q.weight": [128],
"single_transformer_blocks.3.attn.to_k.bias": [3072],
"single_transformer_blocks.3.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.3.attn.to_q.bias": [3072],
"single_transformer_blocks.3.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.3.attn.to_v.bias": [3072],
"single_transformer_blocks.3.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.3.norm.linear.bias": [9216],
"single_transformer_blocks.3.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.3.proj_mlp.bias": [12288],
"single_transformer_blocks.3.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.3.proj_out.bias": [3072],
"single_transformer_blocks.3.proj_out.weight": [3072, 15360],
"single_transformer_blocks.4.attn.norm_k.weight": [128],
"single_transformer_blocks.4.attn.norm_q.weight": [128],
"single_transformer_blocks.4.attn.to_k.bias": [3072],
"single_transformer_blocks.4.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.4.attn.to_q.bias": [3072],
"single_transformer_blocks.4.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.4.attn.to_v.bias": [3072],
"single_transformer_blocks.4.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.4.norm.linear.bias": [9216],
"single_transformer_blocks.4.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.4.proj_mlp.bias": [12288],
"single_transformer_blocks.4.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.4.proj_out.bias": [3072],
"single_transformer_blocks.4.proj_out.weight": [3072, 15360],
"single_transformer_blocks.5.attn.norm_k.weight": [128],
"single_transformer_blocks.5.attn.norm_q.weight": [128],
"single_transformer_blocks.5.attn.to_k.bias": [3072],
"single_transformer_blocks.5.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.5.attn.to_q.bias": [3072],
"single_transformer_blocks.5.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.5.attn.to_v.bias": [3072],
"single_transformer_blocks.5.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.5.norm.linear.bias": [9216],
"single_transformer_blocks.5.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.5.proj_mlp.bias": [12288],
"single_transformer_blocks.5.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.5.proj_out.bias": [3072],
"single_transformer_blocks.5.proj_out.weight": [3072, 15360],
"single_transformer_blocks.6.attn.norm_k.weight": [128],
"single_transformer_blocks.6.attn.norm_q.weight": [128],
"single_transformer_blocks.6.attn.to_k.bias": [3072],
"single_transformer_blocks.6.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.6.attn.to_q.bias": [3072],
"single_transformer_blocks.6.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.6.attn.to_v.bias": [3072],
"single_transformer_blocks.6.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.6.norm.linear.bias": [9216],
"single_transformer_blocks.6.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.6.proj_mlp.bias": [12288],
"single_transformer_blocks.6.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.6.proj_out.bias": [3072],
"single_transformer_blocks.6.proj_out.weight": [3072, 15360],
"single_transformer_blocks.7.attn.norm_k.weight": [128],
"single_transformer_blocks.7.attn.norm_q.weight": [128],
"single_transformer_blocks.7.attn.to_k.bias": [3072],
"single_transformer_blocks.7.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.7.attn.to_q.bias": [3072],
"single_transformer_blocks.7.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.7.attn.to_v.bias": [3072],
"single_transformer_blocks.7.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.7.norm.linear.bias": [9216],
"single_transformer_blocks.7.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.7.proj_mlp.bias": [12288],
"single_transformer_blocks.7.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.7.proj_out.bias": [3072],
"single_transformer_blocks.7.proj_out.weight": [3072, 15360],
"single_transformer_blocks.8.attn.norm_k.weight": [128],
"single_transformer_blocks.8.attn.norm_q.weight": [128],
"single_transformer_blocks.8.attn.to_k.bias": [3072],
"single_transformer_blocks.8.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.8.attn.to_q.bias": [3072],
"single_transformer_blocks.8.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.8.attn.to_v.bias": [3072],
"single_transformer_blocks.8.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.8.norm.linear.bias": [9216],
"single_transformer_blocks.8.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.8.proj_mlp.bias": [12288],
"single_transformer_blocks.8.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.8.proj_out.bias": [3072],
"single_transformer_blocks.8.proj_out.weight": [3072, 15360],
"single_transformer_blocks.9.attn.norm_k.weight": [128],
"single_transformer_blocks.9.attn.norm_q.weight": [128],
"single_transformer_blocks.9.attn.to_k.bias": [3072],
"single_transformer_blocks.9.attn.to_k.weight": [3072, 3072],
"single_transformer_blocks.9.attn.to_q.bias": [3072],
"single_transformer_blocks.9.attn.to_q.weight": [3072, 3072],
"single_transformer_blocks.9.attn.to_v.bias": [3072],
"single_transformer_blocks.9.attn.to_v.weight": [3072, 3072],
"single_transformer_blocks.9.norm.linear.bias": [9216],
"single_transformer_blocks.9.norm.linear.weight": [9216, 3072],
"single_transformer_blocks.9.proj_mlp.bias": [12288],
"single_transformer_blocks.9.proj_mlp.weight": [12288, 3072],
"single_transformer_blocks.9.proj_out.bias": [3072],
"single_transformer_blocks.9.proj_out.weight": [3072, 15360],
"time_text_embed.guidance_embedder.linear_1.bias": [3072],
"time_text_embed.guidance_embedder.linear_1.weight": [3072, 256],
"time_text_embed.guidance_embedder.linear_2.bias": [3072],
"time_text_embed.guidance_embedder.linear_2.weight": [3072, 3072],
"time_text_embed.text_embedder.linear_1.bias": [3072],
"time_text_embed.text_embedder.linear_1.weight": [3072, 768],
"time_text_embed.text_embedder.linear_2.bias": [3072],
"time_text_embed.text_embedder.linear_2.weight": [3072, 3072],
"time_text_embed.timestep_embedder.linear_1.bias": [3072],
"time_text_embed.timestep_embedder.linear_1.weight": [3072, 256],
"time_text_embed.timestep_embedder.linear_2.bias": [3072],
"time_text_embed.timestep_embedder.linear_2.weight": [3072, 3072],
"transformer_blocks.0.attn.add_k_proj.bias": [3072],
"transformer_blocks.0.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.add_q_proj.bias": [3072],
"transformer_blocks.0.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.add_v_proj.bias": [3072],
"transformer_blocks.0.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.0.attn.norm_added_k.weight": [128],
"transformer_blocks.0.attn.norm_added_q.weight": [128],
"transformer_blocks.0.attn.norm_k.weight": [128],
"transformer_blocks.0.attn.norm_q.weight": [128],
"transformer_blocks.0.attn.to_add_out.bias": [3072],
"transformer_blocks.0.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.0.attn.to_k.bias": [3072],
"transformer_blocks.0.attn.to_k.weight": [3072, 3072],
"transformer_blocks.0.attn.to_out.0.bias": [3072],
"transformer_blocks.0.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.0.attn.to_q.bias": [3072],
"transformer_blocks.0.attn.to_q.weight": [3072, 3072],
"transformer_blocks.0.attn.to_v.bias": [3072],
"transformer_blocks.0.attn.to_v.weight": [3072, 3072],
"transformer_blocks.0.ff.net.0.proj.bias": [12288],
"transformer_blocks.0.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.0.ff.net.2.bias": [3072],
"transformer_blocks.0.ff.net.2.weight": [3072, 12288],
"transformer_blocks.0.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.0.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.0.ff_context.net.2.bias": [3072],
"transformer_blocks.0.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.0.norm1.linear.bias": [18432],
"transformer_blocks.0.norm1.linear.weight": [18432, 3072],
"transformer_blocks.0.norm1_context.linear.bias": [18432],
"transformer_blocks.0.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.1.attn.add_k_proj.bias": [3072],
"transformer_blocks.1.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.add_q_proj.bias": [3072],
"transformer_blocks.1.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.add_v_proj.bias": [3072],
"transformer_blocks.1.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.1.attn.norm_added_k.weight": [128],
"transformer_blocks.1.attn.norm_added_q.weight": [128],
"transformer_blocks.1.attn.norm_k.weight": [128],
"transformer_blocks.1.attn.norm_q.weight": [128],
"transformer_blocks.1.attn.to_add_out.bias": [3072],
"transformer_blocks.1.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.1.attn.to_k.bias": [3072],
"transformer_blocks.1.attn.to_k.weight": [3072, 3072],
"transformer_blocks.1.attn.to_out.0.bias": [3072],
"transformer_blocks.1.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.1.attn.to_q.bias": [3072],
"transformer_blocks.1.attn.to_q.weight": [3072, 3072],
"transformer_blocks.1.attn.to_v.bias": [3072],
"transformer_blocks.1.attn.to_v.weight": [3072, 3072],
"transformer_blocks.1.ff.net.0.proj.bias": [12288],
"transformer_blocks.1.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.1.ff.net.2.bias": [3072],
"transformer_blocks.1.ff.net.2.weight": [3072, 12288],
"transformer_blocks.1.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.1.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.1.ff_context.net.2.bias": [3072],
"transformer_blocks.1.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.1.norm1.linear.bias": [18432],
"transformer_blocks.1.norm1.linear.weight": [18432, 3072],
"transformer_blocks.1.norm1_context.linear.bias": [18432],
"transformer_blocks.1.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.2.attn.add_k_proj.bias": [3072],
"transformer_blocks.2.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.add_q_proj.bias": [3072],
"transformer_blocks.2.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.add_v_proj.bias": [3072],
"transformer_blocks.2.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.2.attn.norm_added_k.weight": [128],
"transformer_blocks.2.attn.norm_added_q.weight": [128],
"transformer_blocks.2.attn.norm_k.weight": [128],
"transformer_blocks.2.attn.norm_q.weight": [128],
"transformer_blocks.2.attn.to_add_out.bias": [3072],
"transformer_blocks.2.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.2.attn.to_k.bias": [3072],
"transformer_blocks.2.attn.to_k.weight": [3072, 3072],
"transformer_blocks.2.attn.to_out.0.bias": [3072],
"transformer_blocks.2.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.2.attn.to_q.bias": [3072],
"transformer_blocks.2.attn.to_q.weight": [3072, 3072],
"transformer_blocks.2.attn.to_v.bias": [3072],
"transformer_blocks.2.attn.to_v.weight": [3072, 3072],
"transformer_blocks.2.ff.net.0.proj.bias": [12288],
"transformer_blocks.2.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.2.ff.net.2.bias": [3072],
"transformer_blocks.2.ff.net.2.weight": [3072, 12288],
"transformer_blocks.2.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.2.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.2.ff_context.net.2.bias": [3072],
"transformer_blocks.2.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.2.norm1.linear.bias": [18432],
"transformer_blocks.2.norm1.linear.weight": [18432, 3072],
"transformer_blocks.2.norm1_context.linear.bias": [18432],
"transformer_blocks.2.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.3.attn.add_k_proj.bias": [3072],
"transformer_blocks.3.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.add_q_proj.bias": [3072],
"transformer_blocks.3.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.add_v_proj.bias": [3072],
"transformer_blocks.3.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.3.attn.norm_added_k.weight": [128],
"transformer_blocks.3.attn.norm_added_q.weight": [128],
"transformer_blocks.3.attn.norm_k.weight": [128],
"transformer_blocks.3.attn.norm_q.weight": [128],
"transformer_blocks.3.attn.to_add_out.bias": [3072],
"transformer_blocks.3.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.3.attn.to_k.bias": [3072],
"transformer_blocks.3.attn.to_k.weight": [3072, 3072],
"transformer_blocks.3.attn.to_out.0.bias": [3072],
"transformer_blocks.3.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.3.attn.to_q.bias": [3072],
"transformer_blocks.3.attn.to_q.weight": [3072, 3072],
"transformer_blocks.3.attn.to_v.bias": [3072],
"transformer_blocks.3.attn.to_v.weight": [3072, 3072],
"transformer_blocks.3.ff.net.0.proj.bias": [12288],
"transformer_blocks.3.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.3.ff.net.2.bias": [3072],
"transformer_blocks.3.ff.net.2.weight": [3072, 12288],
"transformer_blocks.3.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.3.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.3.ff_context.net.2.bias": [3072],
"transformer_blocks.3.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.3.norm1.linear.bias": [18432],
"transformer_blocks.3.norm1.linear.weight": [18432, 3072],
"transformer_blocks.3.norm1_context.linear.bias": [18432],
"transformer_blocks.3.norm1_context.linear.weight": [18432, 3072],
"transformer_blocks.4.attn.add_k_proj.bias": [3072],
"transformer_blocks.4.attn.add_k_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.add_q_proj.bias": [3072],
"transformer_blocks.4.attn.add_q_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.add_v_proj.bias": [3072],
"transformer_blocks.4.attn.add_v_proj.weight": [3072, 3072],
"transformer_blocks.4.attn.norm_added_k.weight": [128],
"transformer_blocks.4.attn.norm_added_q.weight": [128],
"transformer_blocks.4.attn.norm_k.weight": [128],
"transformer_blocks.4.attn.norm_q.weight": [128],
"transformer_blocks.4.attn.to_add_out.bias": [3072],
"transformer_blocks.4.attn.to_add_out.weight": [3072, 3072],
"transformer_blocks.4.attn.to_k.bias": [3072],
"transformer_blocks.4.attn.to_k.weight": [3072, 3072],
"transformer_blocks.4.attn.to_out.0.bias": [3072],
"transformer_blocks.4.attn.to_out.0.weight": [3072, 3072],
"transformer_blocks.4.attn.to_q.bias": [3072],
"transformer_blocks.4.attn.to_q.weight": [3072, 3072],
"transformer_blocks.4.attn.to_v.bias": [3072],
"transformer_blocks.4.attn.to_v.weight": [3072, 3072],
"transformer_blocks.4.ff.net.0.proj.bias": [12288],
"transformer_blocks.4.ff.net.0.proj.weight": [12288, 3072],
"transformer_blocks.4.ff.net.2.bias": [3072],
"transformer_blocks.4.ff.net.2.weight": [3072, 12288],
"transformer_blocks.4.ff_context.net.0.proj.bias": [12288],
"transformer_blocks.4.ff_context.net.0.proj.weight": [12288, 3072],
"transformer_blocks.4.ff_context.net.2.bias": [3072],
"transformer_blocks.4.ff_context.net.2.weight": [3072, 12288],
"transformer_blocks.4.norm1.linear.bias": [18432],
"transformer_blocks.4.norm1.linear.weight": [18432, 3072],
"transformer_blocks.4.norm1_context.linear.bias": [18432],
"transformer_blocks.4.norm1_context.linear.weight": [18432, 3072],
"x_embedder.bias": [3072],
"x_embedder.weight": [3072, 64],
}
# InstantX FLUX ControlNet config for unit tests.
# Copied from https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/main/config.json
instantx_config = {
"_class_name": "FluxControlNetModel",
"_diffusers_version": "0.30.0.dev0",
"_name_or_path": "/mnt/wangqixun/",
"attention_head_dim": 128,
"axes_dims_rope": [16, 56, 56],
"guidance_embeds": True,
"in_channels": 64,
"joint_attention_dim": 4096,
"num_attention_heads": 24,
"num_layers": 5,
"num_mode": 10,
"num_single_layers": 10,
"patch_size": 1,
"pooled_projection_dim": 768,
}
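
As a quick consistency check (not part of the PR, and assuming the two fixture dicts above are in scope), the fixture shapes agree with the fixture config: the hidden size implied by the config matches the embedder shapes, and the mode embedder has `num_mode` rows.

```python
# Sanity-check sketch relating the fixture config to the fixture shapes above.
hidden_size = instantx_config["num_attention_heads"] * instantx_config["attention_head_dim"]
assert hidden_size == 3072
assert instantx_sd_shapes["x_embedder.weight"] == [hidden_size, instantx_config["in_channels"]]
assert instantx_sd_shapes["context_embedder.weight"] == [hidden_size, instantx_config["joint_attention_dim"]]
assert instantx_sd_shapes["controlnet_mode_embedder.weight"] == [instantx_config["num_mode"], hidden_size]
assert instantx_sd_shapes["time_text_embed.text_embedder.linear_1.weight"] == [hidden_size, instantx_config["pooled_projection_dim"]]
```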

View File

@@ -0,0 +1,108 @@
import sys
import pytest
import torch
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
convert_diffusers_instantx_state_dict_to_bfl_format,
infer_flux_params_from_state_dict,
infer_instantx_num_control_modes_from_state_dict,
is_state_dict_instantx_controlnet,
is_state_dict_xlabs_controlnet,
)
from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_config, instantx_sd_shapes
from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_sd_shapes
@pytest.mark.parametrize(
["sd_shapes", "expected"],
[
(xlabs_sd_shapes, True),
(instantx_sd_shapes, False),
(["foo"], False),
],
)
def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expected: bool):
sd = {k: None for k in sd_shapes}
assert is_state_dict_xlabs_controlnet(sd) == expected
@pytest.mark.parametrize(
["sd_keys", "expected"],
[
(instantx_sd_shapes, True),
(xlabs_sd_shapes, False),
(["foo"], False),
],
)
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
sd = {k: None for k in sd_keys}
assert is_state_dict_instantx_controlnet(sd) == expected
def test_convert_diffusers_instantx_state_dict_to_bfl_format():
"""Smoke test convert_diffusers_instantx_state_dict_to_bfl_format() to ensure that it handles all of the keys."""
sd = {k: torch.zeros(1) for k in instantx_sd_shapes}
bfl_sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
assert bfl_sd is not None
# TODO(ryand): Figure out why some tests in this file are failing on the macOS CI runners. It seems to be related to
# using the meta device. I can't reproduce the issue on my local macOS system.
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_flux_params_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
flux_params = infer_flux_params_from_state_dict(sd)
assert flux_params.in_channels == instantx_config["in_channels"]
assert flux_params.vec_in_dim == instantx_config["pooled_projection_dim"]
assert flux_params.context_in_dim == instantx_config["joint_attention_dim"]
assert flux_params.hidden_size // flux_params.num_heads == instantx_config["attention_head_dim"]
assert flux_params.num_heads == instantx_config["num_attention_heads"]
assert flux_params.mlp_ratio == 4
assert flux_params.depth == instantx_config["num_layers"]
assert flux_params.depth_single_blocks == instantx_config["num_single_layers"]
assert flux_params.axes_dim == instantx_config["axes_dims_rope"]
assert flux_params.theta == 10000
assert flux_params.qkv_bias
assert flux_params.guidance_embed == instantx_config["guidance_embeds"]
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_instantx_num_control_modes_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
assert num_control_modes == instantx_config["num_mode"]
@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_load_instantx_from_state_dict():
# Construct a dummy state_dict with tensors of the correct shape on the meta device.
with torch.device("meta"):
sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}
sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
flux_params = infer_flux_params_from_state_dict(sd)
num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
with torch.device("meta"):
model = InstantXControlNetFlux(flux_params, num_control_modes)
model_sd = model.state_dict()
assert set(model_sd.keys()) == set(sd.keys())
for key, tensor in model_sd.items():
assert isinstance(tensor, torch.Tensor)
assert tensor.shape == sd[key].shape

View File

@@ -0,0 +1,91 @@
# State dict keys and shapes for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
xlabs_sd_shapes = {
"controlnet_blocks.0.bias": [3072],
"controlnet_blocks.0.weight": [3072, 3072],
"controlnet_blocks.1.bias": [3072],
"controlnet_blocks.1.weight": [3072, 3072],
"double_blocks.0.img_attn.norm.key_norm.scale": [128],
"double_blocks.0.img_attn.norm.query_norm.scale": [128],
"double_blocks.0.img_attn.proj.bias": [3072],
"double_blocks.0.img_attn.proj.weight": [3072, 3072],
"double_blocks.0.img_attn.qkv.bias": [9216],
"double_blocks.0.img_attn.qkv.weight": [9216, 3072],
"double_blocks.0.img_mlp.0.bias": [12288],
"double_blocks.0.img_mlp.0.weight": [12288, 3072],
"double_blocks.0.img_mlp.2.bias": [3072],
"double_blocks.0.img_mlp.2.weight": [3072, 12288],
"double_blocks.0.img_mod.lin.bias": [18432],
"double_blocks.0.img_mod.lin.weight": [18432, 3072],
"double_blocks.0.txt_attn.norm.key_norm.scale": [128],
"double_blocks.0.txt_attn.norm.query_norm.scale": [128],
"double_blocks.0.txt_attn.proj.bias": [3072],
"double_blocks.0.txt_attn.proj.weight": [3072, 3072],
"double_blocks.0.txt_attn.qkv.bias": [9216],
"double_blocks.0.txt_attn.qkv.weight": [9216, 3072],
"double_blocks.0.txt_mlp.0.bias": [12288],
"double_blocks.0.txt_mlp.0.weight": [12288, 3072],
"double_blocks.0.txt_mlp.2.bias": [3072],
"double_blocks.0.txt_mlp.2.weight": [3072, 12288],
"double_blocks.0.txt_mod.lin.bias": [18432],
"double_blocks.0.txt_mod.lin.weight": [18432, 3072],
"double_blocks.1.img_attn.norm.key_norm.scale": [128],
"double_blocks.1.img_attn.norm.query_norm.scale": [128],
"double_blocks.1.img_attn.proj.bias": [3072],
"double_blocks.1.img_attn.proj.weight": [3072, 3072],
"double_blocks.1.img_attn.qkv.bias": [9216],
"double_blocks.1.img_attn.qkv.weight": [9216, 3072],
"double_blocks.1.img_mlp.0.bias": [12288],
"double_blocks.1.img_mlp.0.weight": [12288, 3072],
"double_blocks.1.img_mlp.2.bias": [3072],
"double_blocks.1.img_mlp.2.weight": [3072, 12288],
"double_blocks.1.img_mod.lin.bias": [18432],
"double_blocks.1.img_mod.lin.weight": [18432, 3072],
"double_blocks.1.txt_attn.norm.key_norm.scale": [128],
"double_blocks.1.txt_attn.norm.query_norm.scale": [128],
"double_blocks.1.txt_attn.proj.bias": [3072],
"double_blocks.1.txt_attn.proj.weight": [3072, 3072],
"double_blocks.1.txt_attn.qkv.bias": [9216],
"double_blocks.1.txt_attn.qkv.weight": [9216, 3072],
"double_blocks.1.txt_mlp.0.bias": [12288],
"double_blocks.1.txt_mlp.0.weight": [12288, 3072],
"double_blocks.1.txt_mlp.2.bias": [3072],
"double_blocks.1.txt_mlp.2.weight": [3072, 12288],
"double_blocks.1.txt_mod.lin.bias": [18432],
"double_blocks.1.txt_mod.lin.weight": [18432, 3072],
"guidance_in.in_layer.bias": [3072],
"guidance_in.in_layer.weight": [3072, 256],
"guidance_in.out_layer.bias": [3072],
"guidance_in.out_layer.weight": [3072, 3072],
"img_in.bias": [3072],
"img_in.weight": [3072, 64],
"input_hint_block.0.bias": [16],
"input_hint_block.0.weight": [16, 3, 3, 3],
"input_hint_block.10.bias": [16],
"input_hint_block.10.weight": [16, 16, 3, 3],
"input_hint_block.12.bias": [16],
"input_hint_block.12.weight": [16, 16, 3, 3],
"input_hint_block.14.bias": [16],
"input_hint_block.14.weight": [16, 16, 3, 3],
"input_hint_block.2.bias": [16],
"input_hint_block.2.weight": [16, 16, 3, 3],
"input_hint_block.4.bias": [16],
"input_hint_block.4.weight": [16, 16, 3, 3],
"input_hint_block.6.bias": [16],
"input_hint_block.6.weight": [16, 16, 3, 3],
"input_hint_block.8.bias": [16],
"input_hint_block.8.weight": [16, 16, 3, 3],
"pos_embed_input.bias": [3072],
"pos_embed_input.weight": [3072, 64],
"time_in.in_layer.bias": [3072],
"time_in.in_layer.weight": [3072, 256],
"time_in.out_layer.bias": [3072],
"time_in.out_layer.weight": [3072, 3072],
"txt_in.bias": [3072],
"txt_in.weight": [3072, 4096],
"vector_in.in_layer.bias": [3072],
"vector_in.in_layer.weight": [3072, 768],
"vector_in.out_layer.bias": [3072],
"vector_in.out_layer.weight": [3072, 3072],
}