Add support for FLUX ControlNet models (XLabs and InstantX) (#7070)
## Summary

Add support for FLUX ControlNet models (XLabs and InstantX).

## QA Instructions

- [x] SD1 and SDXL ControlNets, since the ModelLoaderRegistry calls were changed.
- [x] Single XLabs controlnet
- [x] Single InstantX union controlnet
- [x] Single InstantX controlnet
- [x] Single Shakker Labs Union controlnet
- [x] Multiple controlnets
- [x] Weight, start, and end params all work as expected
- [x] Can be used with image-to-image and inpainting.
- [x] Clear error message if no VAE is passed when using an InstantX controlnet.
- [x] Install InstantX ControlNet in diffusers format from HF repo (`InstantX/FLUX.1-dev-Controlnet-Union`)
- [x] Test all FLUX ControlNet starter models

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [ ] _Documentation added / updated (if applicable)_
invokeai/app/invocations/fields.py
@@ -192,6 +192,7 @@ class FieldDescriptions:
     freeu_s2 = 'Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to mitigate the "oversmoothing effect" in the enhanced denoising process.'
     freeu_b1 = "Scaling factor for stage 1 to amplify the contributions of backbone features."
     freeu_b2 = "Scaling factor for stage 2 to amplify the contributions of backbone features."
+    instantx_control_mode = "The control mode for InstantX ControlNet union models. Ignored for other ControlNet models. The standard mapping is: canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6). Negative values will be treated as 'None'."


 class ImageField(BaseModel):
99  invokeai/app/invocations/flux_controlnet.py  Normal file
@@ -0,0 +1,99 @@
from pydantic import BaseModel, Field, field_validator, model_validator

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    Classification,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, ImageField, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES


class FluxControlNetField(BaseModel):
    image: ImageField = Field(description="The control image")
    control_model: ModelIdentifierField = Field(description="The ControlNet model to use")
    control_weight: float | list[float] = Field(default=1, description="The weight given to the ControlNet")
    begin_step_percent: float = Field(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = Field(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
    instantx_control_mode: int | None = Field(default=-1, description=FieldDescriptions.instantx_control_mode)

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self


@invocation_output("flux_controlnet_output")
class FluxControlNetOutput(BaseInvocationOutput):
    """FLUX ControlNet info"""

    control: FluxControlNetField = OutputField(description=FieldDescriptions.control)


@invocation(
    "flux_controlnet",
    title="FLUX ControlNet",
    tags=["controlnet", "flux"],
    category="controlnet",
    version="1.0.0",
    classification=Classification.Prototype,
)
class FluxControlNetInvocation(BaseInvocation):
    """Collect FLUX ControlNet info to pass to other nodes."""

    image: ImageField = InputField(description="The control image")
    control_model: ModelIdentifierField = InputField(
        description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
    )
    control_weight: float | list[float] = InputField(
        default=1.0, ge=-1, le=2, description="The weight given to the ControlNet"
    )
    begin_step_percent: float = InputField(
        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
    )
    end_step_percent: float = InputField(
        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
    )
    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
    # Note: We default to -1 instead of None, because in the workflow editor UI None is not currently supported.
    instantx_control_mode: int | None = InputField(default=-1, description=FieldDescriptions.instantx_control_mode)

    @field_validator("control_weight")
    @classmethod
    def validate_control_weight(cls, v: float | list[float]) -> float | list[float]:
        validate_weights(v)
        return v

    @model_validator(mode="after")
    def validate_begin_end_step_percent(self):
        validate_begin_end_step(self.begin_step_percent, self.end_step_percent)
        return self

    def invoke(self, context: InvocationContext) -> FluxControlNetOutput:
        return FluxControlNetOutput(
            control=FluxControlNetField(
                image=self.image,
                control_model=self.control_model,
                control_weight=self.control_weight,
                begin_step_percent=self.begin_step_percent,
                end_step_percent=self.end_step_percent,
                resize_mode=self.resize_mode,
                instantx_control_mode=self.instantx_control_mode,
            ),
        )
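For quick reference, the InstantX union control-mode mapping described in `FieldDescriptions.instantx_control_mode` can be written as a lookup table. This is an illustrative sketch only, not code from this PR:

```python
# Illustrative only: the InstantX union control modes from
# FieldDescriptions.instantx_control_mode, as a lookup table.
INSTANTX_CONTROL_MODES = {
    "canny": 0,
    "tile": 1,
    "depth": 2,
    "blur": 3,
    "pose": 4,
    "gray": 5,
    "low quality": 6,
}
# The node defaults to -1; negative values are treated as 'None', in which case the
# model falls back to a zero mode-embedding (see InstantXControlNetFlux.forward below).
```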
invokeai/app/invocations/flux_denoise.py
@@ -16,11 +16,16 @@ from invokeai.app.invocations.fields import (
     WithBoard,
     WithMetadata,
 )
-from invokeai.app.invocations.model import TransformerField
+from invokeai.app.invocations.flux_controlnet import FluxControlNetField
+from invokeai.app.invocations.model import TransformerField, VAEField
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
+from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
+from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
 from invokeai.backend.flux.denoise import denoise
-from invokeai.backend.flux.inpaint_extension import InpaintExtension
+from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
+from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.sampling_utils import (
     clip_timestep_schedule_fractional,
@@ -44,7 +49,7 @@ from invokeai.backend.util.devices import TorchDevice
     title="FLUX Denoise",
     tags=["image", "flux"],
     category="image",
-    version="3.0.0",
+    version="3.1.0",
     classification=Classification.Prototype,
 )
 class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
@@ -87,6 +92,13 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         description="The guidance strength. Higher values adhere more strictly to the prompt, and will produce less diverse images. FLUX dev only, ignored for schnell.",
     )
     seed: int = InputField(default=0, description="Randomness seed for reproducibility.")
+    control: FluxControlNetField | list[FluxControlNetField] | None = InputField(
+        default=None, input=Input.Connection, description="ControlNet models."
+    )
+    controlnet_vae: VAEField | None = InputField(
+        description=FieldDescriptions.vae,
+        input=Input.Connection,
+    )

     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
@@ -167,8 +179,8 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):

         inpaint_mask = self._prep_inpaint_mask(context, x)

-        b, _c, h, w = x.shape
-        img_ids = generate_img_ids(h=h, w=w, batch_size=b, device=x.device, dtype=x.dtype)
+        b, _c, latent_h, latent_w = x.shape
+        img_ids = generate_img_ids(h=latent_h, w=latent_w, batch_size=b, device=x.device, dtype=x.dtype)

         bs, t5_seq_len, _ = t5_embeddings.shape
         txt_ids = torch.zeros(bs, t5_seq_len, 3, dtype=inference_dtype, device=TorchDevice.choose_torch_device())
@@ -192,12 +204,21 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
             noise=noise,
         )

-        with (
-            transformer_info.model_on_device() as (cached_weights, transformer),
-            ExitStack() as exit_stack,
-        ):
-            assert isinstance(transformer, Flux)
+        with ExitStack() as exit_stack:
+            # Prepare ControlNet extensions.
+            # Note: We do this before loading the transformer model to minimize peak memory (see implementation).
+            controlnet_extensions = self._prep_controlnet_extensions(
+                context=context,
+                exit_stack=exit_stack,
+                latent_height=latent_h,
+                latent_width=latent_w,
+                dtype=inference_dtype,
+                device=x.device,
+            )
+
+            # Load the transformer model.
+            (cached_weights, transformer) = exit_stack.enter_context(transformer_info.model_on_device())
+            assert isinstance(transformer, Flux)
             config = transformer_info.config
             assert config is not None

@@ -242,6 +263,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
                 step_callback=self._build_step_callback(context),
                 guidance=self.guidance,
                 inpaint_extension=inpaint_extension,
+                controlnet_extensions=controlnet_extensions,
             )

         x = unpack(x.float(), self.height, self.width)
@@ -288,6 +310,104 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard):
         # `latents`.
         return mask.expand_as(latents)

+    def _prep_controlnet_extensions(
+        self,
+        context: InvocationContext,
+        exit_stack: ExitStack,
+        latent_height: int,
+        latent_width: int,
+        dtype: torch.dtype,
+        device: torch.device,
+    ) -> list[XLabsControlNetExtension | InstantXControlNetExtension]:
+        # Normalize the controlnet input to list[FluxControlNetField].
+        controlnets: list[FluxControlNetField]
+        if self.control is None:
+            controlnets = []
+        elif isinstance(self.control, FluxControlNetField):
+            controlnets = [self.control]
+        elif isinstance(self.control, list):
+            controlnets = self.control
+        else:
+            raise ValueError(f"Unsupported controlnet type: {type(self.control)}")
+
+        # TODO(ryand): Add a field to the model config so that we can distinguish between XLabs and InstantX ControlNets
+        # before loading the models. Then make sure that all VAE encoding is done before loading the ControlNets to
+        # minimize peak memory.
+
+        # First, load the ControlNet models so that we can determine the ControlNet types.
+        controlnet_models = [context.models.load(controlnet.control_model) for controlnet in controlnets]
+
+        # Calculate the controlnet conditioning tensors.
+        # We do this before loading the ControlNet models because it may require running the VAE, and we are trying to
+        # keep peak memory down.
+        controlnet_conds: list[torch.Tensor] = []
+        for controlnet, controlnet_model in zip(controlnets, controlnet_models, strict=True):
+            image = context.images.get_pil(controlnet.image.image_name)
+            if isinstance(controlnet_model.model, InstantXControlNetFlux):
+                if self.controlnet_vae is None:
+                    raise ValueError("A ControlNet VAE is required when using an InstantX FLUX ControlNet.")
+                vae_info = context.models.load(self.controlnet_vae.vae)
+                controlnet_conds.append(
+                    InstantXControlNetExtension.prepare_controlnet_cond(
+                        controlnet_image=image,
+                        vae_info=vae_info,
+                        latent_height=latent_height,
+                        latent_width=latent_width,
+                        dtype=dtype,
+                        device=device,
+                        resize_mode=controlnet.resize_mode,
+                    )
+                )
+            elif isinstance(controlnet_model.model, XLabsControlNetFlux):
+                controlnet_conds.append(
+                    XLabsControlNetExtension.prepare_controlnet_cond(
+                        controlnet_image=image,
+                        latent_height=latent_height,
+                        latent_width=latent_width,
+                        dtype=dtype,
+                        device=device,
+                        resize_mode=controlnet.resize_mode,
+                    )
+                )
+
+        # Finally, load the ControlNet models and initialize the ControlNet extensions.
+        controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension] = []
+        for controlnet, controlnet_cond, controlnet_model in zip(
+            controlnets, controlnet_conds, controlnet_models, strict=True
+        ):
+            model = exit_stack.enter_context(controlnet_model)
+
+            if isinstance(model, XLabsControlNetFlux):
+                controlnet_extensions.append(
+                    XLabsControlNetExtension(
+                        model=model,
+                        controlnet_cond=controlnet_cond,
+                        weight=controlnet.control_weight,
+                        begin_step_percent=controlnet.begin_step_percent,
+                        end_step_percent=controlnet.end_step_percent,
+                    )
+                )
+            elif isinstance(model, InstantXControlNetFlux):
+                instantx_control_mode: torch.Tensor | None = None
+                if controlnet.instantx_control_mode is not None and controlnet.instantx_control_mode >= 0:
+                    instantx_control_mode = torch.tensor(controlnet.instantx_control_mode, dtype=torch.long)
+                    instantx_control_mode = instantx_control_mode.reshape([-1, 1])
+
+                controlnet_extensions.append(
+                    InstantXControlNetExtension(
+                        model=model,
+                        controlnet_cond=controlnet_cond,
+                        instantx_control_mode=instantx_control_mode,
+                        weight=controlnet.control_weight,
+                        begin_step_percent=controlnet.begin_step_percent,
+                        end_step_percent=controlnet.end_step_percent,
+                    )
+                )
+            else:
+                raise ValueError(f"Unsupported ControlNet model type: {type(model)}")
+
+        return controlnet_extensions
+
     def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]:
         for lora in self.transformer.loras:
             lora_info = context.models.load(lora.lora)
0  invokeai/backend/flux/controlnet/__init__.py  Normal file
58  invokeai/backend/flux/controlnet/controlnet_flux_output.py  Normal file
@@ -0,0 +1,58 @@
from dataclasses import dataclass

import torch


@dataclass
class ControlNetFluxOutput:
    single_block_residuals: list[torch.Tensor] | None
    double_block_residuals: list[torch.Tensor] | None

    def apply_weight(self, weight: float):
        if self.single_block_residuals is not None:
            for i in range(len(self.single_block_residuals)):
                self.single_block_residuals[i] = self.single_block_residuals[i] * weight
        if self.double_block_residuals is not None:
            for i in range(len(self.double_block_residuals)):
                self.double_block_residuals[i] = self.double_block_residuals[i] * weight


def add_tensor_lists_elementwise(
    list1: list[torch.Tensor] | None, list2: list[torch.Tensor] | None
) -> list[torch.Tensor] | None:
    """Elementwise addition of two tensor lists, either of which may be None."""
    if list1 is None and list2 is None:
        return None
    if list1 is None:
        return list2
    if list2 is None:
        return list1

    new_list: list[torch.Tensor] = []
    for list1_tensor, list2_tensor in zip(list1, list2, strict=True):
        new_list.append(list1_tensor + list2_tensor)
    return new_list


def add_controlnet_flux_outputs(
    controlnet_output_1: ControlNetFluxOutput, controlnet_output_2: ControlNetFluxOutput
) -> ControlNetFluxOutput:
    return ControlNetFluxOutput(
        single_block_residuals=add_tensor_lists_elementwise(
            controlnet_output_1.single_block_residuals, controlnet_output_2.single_block_residuals
        ),
        double_block_residuals=add_tensor_lists_elementwise(
            controlnet_output_1.double_block_residuals, controlnet_output_2.double_block_residuals
        ),
    )


def sum_controlnet_flux_outputs(
    controlnet_outputs: list[ControlNetFluxOutput],
) -> ControlNetFluxOutput:
    controlnet_output_sum = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)

    for controlnet_output in controlnet_outputs:
        controlnet_output_sum = add_controlnet_flux_outputs(controlnet_output_sum, controlnet_output)

    return controlnet_output_sum
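A minimal sketch of how these helpers behave when merging residuals from multiple ControlNets. The imports assume the module lands at the path shown above; the tensor shapes are arbitrary examples:

```python
import torch

from invokeai.backend.flux.controlnet.controlnet_flux_output import (
    ControlNetFluxOutput,
    sum_controlnet_flux_outputs,
)

a = ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=[torch.ones(2, 4)])
b = ControlNetFluxOutput(single_block_residuals=[torch.ones(2, 4)], double_block_residuals=[torch.ones(2, 4)])

merged = sum_controlnet_flux_outputs([a, b])
# Double-block residuals are summed elementwise; the lone single-block list passes
# through unchanged because the other operand is None.
assert torch.allclose(merged.double_block_residuals[0], torch.full((2, 4), 2.0))
assert torch.allclose(merged.single_block_residuals[0], torch.ones(2, 4))
```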
180  invokeai/backend/flux/controlnet/instantx_controlnet_flux.py  Normal file
@@ -0,0 +1,180 @@
# This file was initially copied from:
# https://github.com/huggingface/diffusers/blob/99f608218caa069a2f16dcf9efab46959b15aec0/src/diffusers/models/controlnet_flux.py

from dataclasses import dataclass

import torch
import torch.nn as nn

from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)


@dataclass
class InstantXControlNetFluxOutput:
    controlnet_block_samples: list[torch.Tensor] | None
    controlnet_single_block_samples: list[torch.Tensor] | None


# NOTE(ryand): Mapping between diffusers FLUX transformer params and BFL FLUX transformer params:
# - Diffusers: BFL
# - in_channels: in_channels
# - num_layers: depth
# - num_single_layers: depth_single_blocks
# - attention_head_dim: hidden_size // num_heads
# - num_attention_heads: num_heads
# - joint_attention_dim: context_in_dim
# - pooled_projection_dim: vec_in_dim
# - guidance_embeds: guidance_embed
# - axes_dims_rope: axes_dim


class InstantXControlNetFlux(torch.nn.Module):
    def __init__(self, params: FluxParams, num_control_modes: int | None = None):
        """
        Args:
            params (FluxParams): The parameters for the FLUX model.
            num_control_modes (int | None, optional): The number of controlnet modes. If non-None, then the model is a
                'union controlnet' model and expects a mode conditioning input at runtime.
        """
        super().__init__()

        # The following modules mirror the base FLUX transformer model.
        # -------------------------------------------------------------
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        # The following modules are specific to the ControlNet model.
        # -----------------------------------------------------------
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.double_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))

        self.controlnet_single_blocks = nn.ModuleList([])
        for _ in range(len(self.single_blocks)):
            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.hidden_size, self.hidden_size)))

        self.is_union = False
        if num_control_modes is not None:
            self.is_union = True
            self.controlnet_mode_embedder = nn.Embedding(num_control_modes, self.hidden_size)

        self.controlnet_x_embedder = zero_module(torch.nn.Linear(self.in_channels, self.hidden_size))

    def forward(
        self,
        controlnet_cond: torch.Tensor,
        controlnet_mode: torch.Tensor | None,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        timesteps: torch.Tensor,
        y: torch.Tensor,
        guidance: torch.Tensor | None = None,
    ) -> InstantXControlNetFluxOutput:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        img = self.img_in(img)

        # Add controlnet_cond embedding.
        img = img + self.controlnet_x_embedder(controlnet_cond)

        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # If this is a union ControlNet, then concat the control mode embedding to the T5 text embedding.
        if self.is_union:
            if controlnet_mode is None:
                # We allow users to enter 'None' as the controlnet_mode if they don't want to worry about this input.
                # We've chosen to use a zero-embedding in this case.
                zero_index = torch.zeros([1, 1], dtype=torch.long, device=txt.device)
                controlnet_mode_emb = torch.zeros_like(self.controlnet_mode_embedder(zero_index))
            else:
                controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
            txt = torch.cat([controlnet_mode_emb, txt], dim=1)
            txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)
        else:
            assert controlnet_mode is None

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        double_block_samples: list[torch.Tensor] = []
        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            double_block_samples.append(img)

        img = torch.cat((txt, img), 1)

        single_block_samples: list[torch.Tensor] = []
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
            single_block_samples.append(img[:, txt.shape[1] :])

        # ControlNet Block
        controlnet_double_block_samples: list[torch.Tensor] = []
        for double_block_sample, controlnet_block in zip(double_block_samples, self.controlnet_blocks, strict=True):
            double_block_sample = controlnet_block(double_block_sample)
            controlnet_double_block_samples.append(double_block_sample)

        controlnet_single_block_samples: list[torch.Tensor] = []
        for single_block_sample, controlnet_block in zip(
            single_block_samples, self.controlnet_single_blocks, strict=True
        ):
            single_block_sample = controlnet_block(single_block_sample)
            controlnet_single_block_samples.append(single_block_sample)

        return InstantXControlNetFluxOutput(
            controlnet_block_samples=controlnet_double_block_samples or None,
            controlnet_single_block_samples=controlnet_single_block_samples or None,
        )
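When `is_union` is set, one mode token is prepended to the text stream and the first positional id is duplicated to match. A plain-torch sketch of the shape bookkeeping (hidden_size 3072 and a T5 sequence length of 512 are example values, not requirements of the model):

```python
import torch

bs, t5_seq_len, hidden_size = 1, 512, 3072
txt = torch.zeros(bs, t5_seq_len, hidden_size)
txt_ids = torch.zeros(bs, t5_seq_len, 3)
# Stand-in for controlnet_mode_embedder's output (all zeros when controlnet_mode is None).
controlnet_mode_emb = torch.zeros(bs, 1, hidden_size)

txt = torch.cat([controlnet_mode_emb, txt], dim=1)
txt_ids = torch.cat([txt_ids[:, :1, :], txt_ids], dim=1)  # duplicate the first position id
assert txt.shape == (bs, t5_seq_len + 1, hidden_size)
assert txt_ids.shape == (bs, t5_seq_len + 1, 3)
```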
295  invokeai/backend/flux/controlnet/state_dict_utils.py  Normal file
@@ -0,0 +1,295 @@
from typing import Any, Dict

import torch

from invokeai.backend.flux.model import FluxParams


def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
    """Is the state dict for an XLabs ControlNet model?

    This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
    """
    # If all of the expected keys are present, then this is very likely an XLabs ControlNet model.
    expected_keys = {
        "controlnet_blocks.0.bias",
        "controlnet_blocks.0.weight",
        "input_hint_block.0.bias",
        "input_hint_block.0.weight",
        "pos_embed_input.bias",
        "pos_embed_input.weight",
    }

    if expected_keys.issubset(sd.keys()):
        return True
    return False


def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
    """Is the state dict for an InstantX ControlNet model?

    This is intended to be a reasonably high-precision detector, but it is not guaranteed to have perfect precision.
    """
    # If all of the expected keys are present, then this is very likely an InstantX ControlNet model.
    expected_keys = {
        "controlnet_blocks.0.bias",
        "controlnet_blocks.0.weight",
        "controlnet_x_embedder.bias",
        "controlnet_x_embedder.weight",
    }

    if expected_keys.issubset(sd.keys()):
        return True
    return False


def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
    """Fuse weights along dimension 0.

    Used to fuse q, k, v attention weights into a single qkv tensor when converting from diffusers to BFL format.
    """
    # TODO(ryand): Double check dim=0 is correct.
    return torch.cat(t, dim=0)


def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
    sd: Dict[str, torch.Tensor], double_block_index: int
) -> Dict[str, torch.Tensor]:
    """Convert the state dict for a double block from diffusers format to BFL format."""
    to_prefix = f"double_blocks.{double_block_index}"
    from_prefix = f"transformer_blocks.{double_block_index}"

    new_sd: dict[str, torch.Tensor] = {}

    # Check one key to determine if this block exists.
    if f"{from_prefix}.attn.add_q_proj.bias" not in sd:
        return new_sd

    # txt_attn.qkv
    new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.add_q_proj.bias"),
        sd.pop(f"{from_prefix}.attn.add_k_proj.bias"),
        sd.pop(f"{from_prefix}.attn.add_v_proj.bias"),
    )
    new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.add_q_proj.weight"),
        sd.pop(f"{from_prefix}.attn.add_k_proj.weight"),
        sd.pop(f"{from_prefix}.attn.add_v_proj.weight"),
    )

    # img_attn.qkv
    new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.to_q.bias"),
        sd.pop(f"{from_prefix}.attn.to_k.bias"),
        sd.pop(f"{from_prefix}.attn.to_v.bias"),
    )
    new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.to_q.weight"),
        sd.pop(f"{from_prefix}.attn.to_k.weight"),
        sd.pop(f"{from_prefix}.attn.to_v.weight"),
    )

    # Handle basic 1-to-1 key conversions.
    key_map = {
        # img_attn
        "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
        "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
        "attn.to_out.0.weight": "img_attn.proj.weight",
        "attn.to_out.0.bias": "img_attn.proj.bias",
        # img_mlp
        "ff.net.0.proj.weight": "img_mlp.0.weight",
        "ff.net.0.proj.bias": "img_mlp.0.bias",
        "ff.net.2.weight": "img_mlp.2.weight",
        "ff.net.2.bias": "img_mlp.2.bias",
        # img_mod
        "norm1.linear.weight": "img_mod.lin.weight",
        "norm1.linear.bias": "img_mod.lin.bias",
        # txt_attn
        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
        "attn.to_add_out.weight": "txt_attn.proj.weight",
        "attn.to_add_out.bias": "txt_attn.proj.bias",
        # txt_mlp
        "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
        "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
        "ff_context.net.2.weight": "txt_mlp.2.weight",
        "ff_context.net.2.bias": "txt_mlp.2.bias",
        # txt_mod
        "norm1_context.linear.weight": "txt_mod.lin.weight",
        "norm1_context.linear.bias": "txt_mod.lin.bias",
    }
    for from_key, to_key in key_map.items():
        new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")

    return new_sd


def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
    sd: Dict[str, torch.Tensor], single_block_index: int
) -> Dict[str, torch.Tensor]:
    """Convert the state dict for a single block from diffusers format to BFL format."""
    to_prefix = f"single_blocks.{single_block_index}"
    from_prefix = f"single_transformer_blocks.{single_block_index}"

    new_sd: dict[str, torch.Tensor] = {}

    # Check one key to determine if this block exists.
    if f"{from_prefix}.attn.to_q.bias" not in sd:
        return new_sd

    # linear1 (qkv)
    new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.to_q.bias"),
        sd.pop(f"{from_prefix}.attn.to_k.bias"),
        sd.pop(f"{from_prefix}.attn.to_v.bias"),
        sd.pop(f"{from_prefix}.proj_mlp.bias"),
    )
    new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights(
        sd.pop(f"{from_prefix}.attn.to_q.weight"),
        sd.pop(f"{from_prefix}.attn.to_k.weight"),
        sd.pop(f"{from_prefix}.attn.to_v.weight"),
        sd.pop(f"{from_prefix}.proj_mlp.weight"),
    )

    # Handle basic 1-to-1 key conversions.
    key_map = {
        # linear2
        "proj_out.weight": "linear2.weight",
        "proj_out.bias": "linear2.bias",
        # modulation
        "norm.linear.weight": "modulation.lin.weight",
        "norm.linear.bias": "modulation.lin.bias",
        # norm
        "attn.norm_k.weight": "norm.key_norm.scale",
        "attn.norm_q.weight": "norm.query_norm.scale",
    }
    for from_key, to_key in key_map.items():
        new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")

    return new_sd


def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Convert an InstantX ControlNet state dict to the format that can be loaded by our internal
    InstantXControlNetFlux model.

    The original InstantX ControlNet model was developed to be used in diffusers. We have ported the original
    implementation to InstantXControlNetFlux to make it compatible with BFL-style models. This function converts the
    original state dict to the format expected by InstantXControlNetFlux.
    """
    # Shallow copy sd so that we can pop keys from it without modifying the original.
    sd = sd.copy()

    new_sd: dict[str, torch.Tensor] = {}

    # Handle basic 1-to-1 key conversions.
    basic_key_map = {
        # Base model keys.
        # ----------------
        # txt_in keys.
        "context_embedder.bias": "txt_in.bias",
        "context_embedder.weight": "txt_in.weight",
        # guidance_in MLPEmbedder keys.
        "time_text_embed.guidance_embedder.linear_1.bias": "guidance_in.in_layer.bias",
        "time_text_embed.guidance_embedder.linear_1.weight": "guidance_in.in_layer.weight",
        "time_text_embed.guidance_embedder.linear_2.bias": "guidance_in.out_layer.bias",
        "time_text_embed.guidance_embedder.linear_2.weight": "guidance_in.out_layer.weight",
        # vector_in MLPEmbedder keys.
        "time_text_embed.text_embedder.linear_1.bias": "vector_in.in_layer.bias",
        "time_text_embed.text_embedder.linear_1.weight": "vector_in.in_layer.weight",
        "time_text_embed.text_embedder.linear_2.bias": "vector_in.out_layer.bias",
        "time_text_embed.text_embedder.linear_2.weight": "vector_in.out_layer.weight",
        # time_in MLPEmbedder keys.
        "time_text_embed.timestep_embedder.linear_1.bias": "time_in.in_layer.bias",
        "time_text_embed.timestep_embedder.linear_1.weight": "time_in.in_layer.weight",
        "time_text_embed.timestep_embedder.linear_2.bias": "time_in.out_layer.bias",
        "time_text_embed.timestep_embedder.linear_2.weight": "time_in.out_layer.weight",
        # img_in keys.
        "x_embedder.bias": "img_in.bias",
        "x_embedder.weight": "img_in.weight",
    }
    for old_key, new_key in basic_key_map.items():
        v = sd.pop(old_key, None)
        if v is not None:
            new_sd[new_key] = v

    # Handle the double_blocks.
    block_index = 0
    while True:
        converted_double_block_sd = _convert_flux_double_block_sd_from_diffusers_to_bfl_format(sd, block_index)
        if len(converted_double_block_sd) == 0:
            break
        new_sd.update(converted_double_block_sd)
        block_index += 1

    # Handle the single_blocks.
    block_index = 0
    while True:
        converted_single_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index)
        if len(converted_single_block_sd) == 0:
            break
        new_sd.update(converted_single_block_sd)
        block_index += 1

    # Transfer controlnet keys as-is.
    for k in list(sd.keys()):
        if k.startswith("controlnet_"):
            new_sd[k] = sd.pop(k)

    # Assert that all keys have been handled.
    assert len(sd) == 0
    return new_sd


def infer_flux_params_from_state_dict(sd: Dict[str, torch.Tensor]) -> FluxParams:
    """Infer the FluxParams from the shape of a FLUX state dict. When a model is distributed in diffusers format, this
    information is all contained in the config.json file that accompanies the model. However, being able to infer the
    params from the state dict enables us to load models (e.g. an InstantX ControlNet) from a single weight file.
    """
    hidden_size = sd["img_in.weight"].shape[0]
    mlp_hidden_dim = sd["double_blocks.0.img_mlp.0.weight"].shape[0]
    # mlp_ratio is a float, but we treat it as an int here to avoid having to think about possible float precision
    # issues. In practice, mlp_ratio is usually 4.
    mlp_ratio = mlp_hidden_dim // hidden_size

    head_dim = sd["double_blocks.0.img_attn.norm.query_norm.scale"].shape[0]
    num_heads = hidden_size // head_dim

    # Count the number of double blocks.
    double_block_index = 0
    while f"double_blocks.{double_block_index}.img_attn.qkv.weight" in sd:
        double_block_index += 1

    # Count the number of single blocks.
    single_block_index = 0
    while f"single_blocks.{single_block_index}.linear1.weight" in sd:
        single_block_index += 1

    return FluxParams(
        in_channels=sd["img_in.weight"].shape[1],
        vec_in_dim=sd["vector_in.in_layer.weight"].shape[1],
        context_in_dim=sd["txt_in.weight"].shape[1],
        hidden_size=hidden_size,
        mlp_ratio=mlp_ratio,
        num_heads=num_heads,
        depth=double_block_index,
        depth_single_blocks=single_block_index,
        # axes_dim cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
        axes_dim=[16, 56, 56],
        # theta cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models.
        theta=10_000,
        qkv_bias="double_blocks.0.img_attn.qkv.bias" in sd,
        guidance_embed="guidance_in.in_layer.weight" in sd,
    )


def infer_instantx_num_control_modes_from_state_dict(sd: Dict[str, torch.Tensor]) -> int | None:
    """Infer the number of ControlNet Union modes from the shape of an InstantX ControlNet state dict.

    Returns None if the model is not a ControlNet Union model. Otherwise returns the number of modes.
    """
    mode_embedder_key = "controlnet_mode_embedder.weight"
    if mode_embedder_key not in sd:
        return None

    return sd[mode_embedder_key].shape[0]
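A self-contained sketch of the shape-based inference above, using FLUX-dev-like toy shapes (hidden_size 3072, 24 heads, mlp_ratio 4) trimmed to one double block and one single block. The shapes are illustrative assumptions, not a real checkpoint:

```python
import torch

from invokeai.backend.flux.controlnet.state_dict_utils import infer_flux_params_from_state_dict

# Toy state dict: only the keys that infer_flux_params_from_state_dict inspects.
sd = {
    "img_in.weight": torch.zeros(3072, 64),
    "vector_in.in_layer.weight": torch.zeros(3072, 768),
    "txt_in.weight": torch.zeros(3072, 4096),
    "double_blocks.0.img_mlp.0.weight": torch.zeros(12288, 3072),   # mlp_ratio = 12288 / 3072 = 4
    "double_blocks.0.img_attn.norm.query_norm.scale": torch.zeros(128),  # head_dim -> 24 heads
    "double_blocks.0.img_attn.qkv.weight": torch.zeros(9216, 3072),
    "double_blocks.0.img_attn.qkv.bias": torch.zeros(9216),
    "single_blocks.0.linear1.weight": torch.zeros(21504, 3072),
    "guidance_in.in_layer.weight": torch.zeros(3072, 256),
}

params = infer_flux_params_from_state_dict(sd)
assert params.hidden_size == 3072 and params.num_heads == 24
assert params.mlp_ratio == 4 and params.depth == 1 and params.depth_single_blocks == 1
assert params.qkv_bias and params.guidance_embed
```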
130  invokeai/backend/flux/controlnet/xlabs_controlnet_flux.py  Normal file
@@ -0,0 +1,130 @@
# This file was initially based on:
# https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/controlnet.py

from dataclasses import dataclass

import torch
from einops import rearrange

from invokeai.backend.flux.controlnet.zero_module import zero_module
from invokeai.backend.flux.model import FluxParams
from invokeai.backend.flux.modules.layers import DoubleStreamBlock, EmbedND, MLPEmbedder, timestep_embedding


@dataclass
class XLabsControlNetFluxOutput:
    controlnet_double_block_residuals: list[torch.Tensor] | None


class XLabsControlNetFlux(torch.nn.Module):
    """A ControlNet model for FLUX.

    The architecture is very similar to the base FLUX model, with the following differences:
    - A `controlnet_depth` parameter is passed to control the number of double_blocks that the ControlNet is applied to.
      In order to keep the ControlNet small, this is typically much less than the depth of the base FLUX model.
    - There is a set of `controlnet_blocks` that are applied to the output of each double_block.
    """

    def __init__(self, params: FluxParams, controlnet_depth: int = 2):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else torch.nn.Identity()
        )
        self.txt_in = torch.nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = torch.nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(controlnet_depth)
            ]
        )

        # Add ControlNet blocks.
        self.controlnet_blocks = torch.nn.ModuleList([])
        for _ in range(controlnet_depth):
            controlnet_block = torch.nn.Linear(self.hidden_size, self.hidden_size)
            controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)
        self.pos_embed_input = torch.nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.input_hint_block = torch.nn.Sequential(
            torch.nn.Conv2d(3, 16, 3, padding=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(16, 16, 3, padding=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(16, 16, 3, padding=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
            torch.nn.SiLU(),
            torch.nn.Conv2d(16, 16, 3, padding=1),
            torch.nn.SiLU(),
            torch.nn.Conv2d(16, 16, 3, padding=1, stride=2),
            torch.nn.SiLU(),
            zero_module(torch.nn.Conv2d(16, 16, 3, padding=1)),
        )

    def forward(
        self,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        controlnet_cond: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        timesteps: torch.Tensor,
        y: torch.Tensor,
        guidance: torch.Tensor | None = None,
    ) -> XLabsControlNetFluxOutput:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # Project the image token sequence and add the packed control-image embedding.
        img = self.img_in(img)
        controlnet_cond = self.input_hint_block(controlnet_cond)
        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        block_res_samples: list[torch.Tensor] = []

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            block_res_samples.append(img)

        controlnet_block_res_samples: list[torch.Tensor] = []
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks, strict=True):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples.append(block_res_sample)

        return XLabsControlNetFluxOutput(controlnet_double_block_residuals=controlnet_block_res_samples)
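The `input_hint_block` downsamples the [-1, 1] control image by 8x (three stride-2 convs), and the 2x2 rearrange then packs it into the same 64-feature tokens as FLUX's packed latents. A shape-only sketch; 1024x1024 is an example size:

```python
import torch
from einops import rearrange

H, W = 1024, 1024
# Shape of input_hint_block's output for an (H, W) control image: 16 channels at 1/8 scale.
hint = torch.randn(1, 16, H // 8, W // 8)
packed = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
# 16 channels * 2 * 2 = 64 features per token, matching the packed latent in_channels.
assert packed.shape == (1, (H // 16) * (W // 16), 64)
```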
12  invokeai/backend/flux/controlnet/zero_module.py  Normal file
@@ -0,0 +1,12 @@
from typing import TypeVar

import torch

T = TypeVar("T", bound=torch.nn.Module)


def zero_module(module: T) -> T:
    """Initialize the parameters of a module to zero."""
    for p in module.parameters():
        torch.nn.init.zeros_(p)
    return module
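A one-line check of why the ControlNet-specific projections are zero-initialized: at initialization they contribute nothing, so the base model's output is unchanged until training moves the weights. A minimal sketch:

```python
import torch

from invokeai.backend.flux.controlnet.zero_module import zero_module

proj = zero_module(torch.nn.Linear(8, 8))
x = torch.randn(2, 8)
# Zero weights and bias: the projection is a no-op contribution at init.
assert torch.equal(proj(x), torch.zeros(2, 8))
```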
invokeai/backend/flux/denoise.py
@@ -3,7 +3,10 @@ from typing import Callable
 import torch
 from tqdm import tqdm

-from invokeai.backend.flux.inpaint_extension import InpaintExtension
+from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput, sum_controlnet_flux_outputs
+from invokeai.backend.flux.extensions.inpaint_extension import InpaintExtension
+from invokeai.backend.flux.extensions.instantx_controlnet_extension import InstantXControlNetExtension
+from invokeai.backend.flux.extensions.xlabs_controlnet_extension import XLabsControlNetExtension
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState

@@ -21,6 +24,7 @@ def denoise(
     step_callback: Callable[[PipelineIntermediateState], None],
     guidance: float,
     inpaint_extension: InpaintExtension | None,
+    controlnet_extensions: list[XLabsControlNetExtension | InstantXControlNetExtension],
 ):
     # step 0 is the initial state
     total_steps = len(timesteps) - 1
@@ -38,6 +42,30 @@ def denoise(
     guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)
     for t_curr, t_prev in tqdm(list(zip(timesteps[:-1], timesteps[1:], strict=True))):
         t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
+
+        # Run ControlNet models.
+        controlnet_residuals: list[ControlNetFluxOutput] = []
+        for controlnet_extension in controlnet_extensions:
+            controlnet_residuals.append(
+                controlnet_extension.run_controlnet(
+                    timestep_index=step - 1,
+                    total_num_timesteps=total_steps,
+                    img=img,
+                    img_ids=img_ids,
+                    txt=txt,
+                    txt_ids=txt_ids,
+                    y=vec,
+                    timesteps=t_vec,
+                    guidance=guidance_vec,
+                )
+            )
+
+        # Merge the ControlNet residuals from multiple ControlNets.
+        # TODO(ryand): We may want to calculate the sum just-in-time to keep peak memory low. Keep in mind that the
+        # controlnet_residuals datastructure is efficient in that it likely contains multiple references to the same
+        # tensors. Calculating the sum materializes each tensor into its own instance.
+        merged_controlnet_residuals = sum_controlnet_flux_outputs(controlnet_residuals)
+
         pred = model(
             img=img,
             img_ids=img_ids,
@@ -46,6 +74,8 @@ def denoise(
             y=vec,
             timesteps=t_vec,
             guidance=guidance_vec,
+            controlnet_double_block_residuals=merged_controlnet_residuals.double_block_residuals,
+            controlnet_single_block_residuals=merged_controlnet_residuals.single_block_residuals,
         )

         preview_img = img - t_curr * pred
0  invokeai/backend/flux/extensions/__init__.py  Normal file
45  invokeai/backend/flux/extensions/base_controlnet_extension.py  Normal file
@@ -0,0 +1,45 @@
import math
from abc import ABC, abstractmethod
from typing import List, Union

import torch

from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput


class BaseControlNetExtension(ABC):
    def __init__(
        self,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        self._weight = weight
        self._begin_step_percent = begin_step_percent
        self._end_step_percent = end_step_percent

    def _get_weight(self, timestep_index: int, total_num_timesteps: int) -> float:
        first_step = math.floor(self._begin_step_percent * total_num_timesteps)
        last_step = math.ceil(self._end_step_percent * total_num_timesteps)

        if timestep_index < first_step or timestep_index > last_step:
            return 0.0

        if isinstance(self._weight, list):
            return self._weight[timestep_index]

        return self._weight

    @abstractmethod
    def run_controlnet(
        self,
        timestep_index: int,
        total_num_timesteps: int,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        y: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: torch.Tensor | None,
    ) -> ControlNetFluxOutput: ...
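A quick sketch of the step-gating behavior in `_get_weight`, using a trivial subclass so the abstract class can be instantiated (the subclass is purely illustrative):

```python
import torch

from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension


class _NoOpExtension(BaseControlNetExtension):
    """Illustrative subclass so the gating logic can be exercised directly."""

    def run_controlnet(self, *args, **kwargs) -> ControlNetFluxOutput:
        return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)


ext = _NoOpExtension(weight=0.8, begin_step_percent=0.25, end_step_percent=0.75)
# With 20 steps, the active window is floor(0.25 * 20) = 5 through ceil(0.75 * 20) = 15, inclusive.
assert ext._get_weight(timestep_index=4, total_num_timesteps=20) == 0.0
assert ext._get_weight(timestep_index=5, total_num_timesteps=20) == 0.8
assert ext._get_weight(timestep_index=15, total_num_timesteps=20) == 0.8
assert ext._get_weight(timestep_index=16, total_num_timesteps=20) == 0.0
```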
194  invokeai/backend/flux/extensions/instantx_controlnet_extension.py  Normal file
@@ -0,0 +1,194 @@
import math
from typing import List, Union

import torch
from PIL.Image import Image

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.invocations.flux_vae_encode import FluxVaeEncodeInvocation
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.instantx_controlnet_flux import (
    InstantXControlNetFlux,
    InstantXControlNetFluxOutput,
)
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension
from invokeai.backend.flux.sampling_utils import pack
from invokeai.backend.model_manager.load.load_base import LoadedModel


class InstantXControlNetExtension(BaseControlNetExtension):
    def __init__(
        self,
        model: InstantXControlNetFlux,
        controlnet_cond: torch.Tensor,
        instantx_control_mode: torch.Tensor | None,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        super().__init__(
            weight=weight,
            begin_step_percent=begin_step_percent,
            end_step_percent=end_step_percent,
        )
        self._model = model
        # The VAE-encoded and 'packed' control image to pass to the ControlNet model.
        self._controlnet_cond = controlnet_cond
        # TODO(ryand): Should we define an enum for the instantx_control_mode? Is it likely to change for future models?
        # The control mode for InstantX ControlNet union models.
        # See the values defined here: https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union#control-mode
        # Expected shape: (batch_size, 1), Expected dtype: torch.long
        # If None, a zero-embedding will be used.
        self._instantx_control_mode = instantx_control_mode

        # TODO(ryand): Pass in these params if a new base transformer / InstantX ControlNet pair get released.
        self._flux_transformer_num_double_blocks = 19
        self._flux_transformer_num_single_blocks = 38

    @classmethod
    def prepare_controlnet_cond(
        cls,
        controlnet_image: Image,
        vae_info: LoadedModel,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        resize_mode: CONTROLNET_RESIZE_VALUES,
    ):
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        resized_controlnet_image = prepare_control_image(
            image=controlnet_image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=device,
            dtype=dtype,
            control_mode="balanced",
            resize_mode=resize_mode,
        )

        # Shift the image from [0, 1] to [-1, 1].
        resized_controlnet_image = resized_controlnet_image * 2 - 1

        # Run VAE encoder.
        controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
        controlnet_cond = pack(controlnet_cond)

        return controlnet_cond

    @classmethod
    def from_controlnet_image(
        cls,
        model: InstantXControlNetFlux,
        controlnet_image: Image,
        instantx_control_mode: torch.Tensor | None,
        vae_info: LoadedModel,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        resize_mode: CONTROLNET_RESIZE_VALUES,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        resized_controlnet_image = prepare_control_image(
            image=controlnet_image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=device,
            dtype=dtype,
            control_mode="balanced",
            resize_mode=resize_mode,
        )

        # Shift the image from [0, 1] to [-1, 1].
        resized_controlnet_image = resized_controlnet_image * 2 - 1

        # Run VAE encoder.
        controlnet_cond = FluxVaeEncodeInvocation.vae_encode(vae_info=vae_info, image_tensor=resized_controlnet_image)
        controlnet_cond = pack(controlnet_cond)

        return cls(
            model=model,
            controlnet_cond=controlnet_cond,
            instantx_control_mode=instantx_control_mode,
            weight=weight,
            begin_step_percent=begin_step_percent,
            end_step_percent=end_step_percent,
        )

    def _instantx_output_to_controlnet_output(
        self, instantx_output: InstantXControlNetFluxOutput
    ) -> ControlNetFluxOutput:
        # The `interval_control` logic here is based on
        # https://github.com/huggingface/diffusers/blob/31058cdaef63ca660a1a045281d156239fba8192/src/diffusers/models/transformers/transformer_flux.py#L507-L511

        # Handle double block residuals.
        double_block_residuals: list[torch.Tensor] = []
        double_block_samples = instantx_output.controlnet_block_samples
        if double_block_samples:
            interval_control = self._flux_transformer_num_double_blocks / len(double_block_samples)
            interval_control = int(math.ceil(interval_control))
            for i in range(self._flux_transformer_num_double_blocks):
                double_block_residuals.append(double_block_samples[i // interval_control])

        # Handle single block residuals.
        single_block_residuals: list[torch.Tensor] = []
        single_block_samples = instantx_output.controlnet_single_block_samples
        if single_block_samples:
            interval_control = self._flux_transformer_num_single_blocks / len(single_block_samples)
            interval_control = int(math.ceil(interval_control))
            for i in range(self._flux_transformer_num_single_blocks):
                single_block_residuals.append(single_block_samples[i // interval_control])

        return ControlNetFluxOutput(
            double_block_residuals=double_block_residuals or None,
            single_block_residuals=single_block_residuals or None,
        )

    def run_controlnet(
        self,
        timestep_index: int,
        total_num_timesteps: int,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        y: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: torch.Tensor | None,
    ) -> ControlNetFluxOutput:
        weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
        if weight < 1e-6:
            return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)

        # Make sure inputs have correct device and dtype.
        self._controlnet_cond = self._controlnet_cond.to(device=img.device, dtype=img.dtype)
        self._instantx_control_mode = (
            self._instantx_control_mode.to(device=img.device) if self._instantx_control_mode is not None else None
        )

        instantx_output: InstantXControlNetFluxOutput = self._model(
            controlnet_cond=self._controlnet_cond,
            controlnet_mode=self._instantx_control_mode,
            img=img,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            timesteps=timesteps,
            y=y,
            guidance=guidance,
        )

        controlnet_output = self._instantx_output_to_controlnet_output(instantx_output)
        controlnet_output.apply_weight(weight)
        return controlnet_output
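The `interval_control` logic stretches N ControlNet block outputs across the base transformer's 19 double blocks by reusing each sample for consecutive blocks. A standalone sketch with a hypothetical 5-block ControlNet:

```python
import math

num_transformer_blocks, num_controlnet_samples = 19, 5
interval_control = int(math.ceil(num_transformer_blocks / num_controlnet_samples))  # = 4
mapping = [i // interval_control for i in range(num_transformer_blocks)]
# Each ControlNet sample is reused for up to `interval_control` consecutive blocks.
assert mapping == [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4]
```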
150  invokeai/backend/flux/extensions/xlabs_controlnet_extension.py  Normal file
@@ -0,0 +1,150 @@
from typing import List, Union

import torch
from PIL.Image import Image

from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR
from invokeai.app.util.controlnet_utils import CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.flux.controlnet.controlnet_flux_output import ControlNetFluxOutput
from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux, XLabsControlNetFluxOutput
from invokeai.backend.flux.extensions.base_controlnet_extension import BaseControlNetExtension


class XLabsControlNetExtension(BaseControlNetExtension):
    def __init__(
        self,
        model: XLabsControlNetFlux,
        controlnet_cond: torch.Tensor,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        super().__init__(
            weight=weight,
            begin_step_percent=begin_step_percent,
            end_step_percent=end_step_percent,
        )

        self._model = model
        # _controlnet_cond is the control image passed to the ControlNet model.
        # Pixel values are in the range [-1, 1]. Shape: (batch_size, 3, height, width).
        self._controlnet_cond = controlnet_cond

        # TODO(ryand): Pass in these params if a new base transformer / XLabs ControlNet pair gets released.
        self._flux_transformer_num_double_blocks = 19
        self._flux_transformer_num_single_blocks = 38

    @classmethod
    def prepare_controlnet_cond(
        cls,
        controlnet_image: Image,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        resize_mode: CONTROLNET_RESIZE_VALUES,
    ):
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        controlnet_cond = prepare_control_image(
            image=controlnet_image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=device,
            dtype=dtype,
            control_mode="balanced",
            resize_mode=resize_mode,
        )

        # Map pixel values from [0, 1] to [-1, 1].
        controlnet_cond = controlnet_cond * 2 - 1

        return controlnet_cond

    @classmethod
    def from_controlnet_image(
        cls,
        model: XLabsControlNetFlux,
        controlnet_image: Image,
        latent_height: int,
        latent_width: int,
        dtype: torch.dtype,
        device: torch.device,
        resize_mode: CONTROLNET_RESIZE_VALUES,
        weight: Union[float, List[float]],
        begin_step_percent: float,
        end_step_percent: float,
    ):
        image_height = latent_height * LATENT_SCALE_FACTOR
        image_width = latent_width * LATENT_SCALE_FACTOR

        controlnet_cond = prepare_control_image(
            image=controlnet_image,
            do_classifier_free_guidance=False,
            width=image_width,
            height=image_height,
            device=device,
            dtype=dtype,
            control_mode="balanced",
            resize_mode=resize_mode,
        )

        # Map pixel values from [0, 1] to [-1, 1].
        controlnet_cond = controlnet_cond * 2 - 1

        return cls(
            model=model,
            controlnet_cond=controlnet_cond,
            weight=weight,
            begin_step_percent=begin_step_percent,
            end_step_percent=end_step_percent,
        )

    def _xlabs_output_to_controlnet_output(self, xlabs_output: XLabsControlNetFluxOutput) -> ControlNetFluxOutput:
        # The modulo index logic used here is based on:
        # https://github.com/XLabs-AI/x-flux/blob/47495425dbed499be1e8e5a6e52628b07349cba2/src/flux/model.py#L198-L200

        # Handle double block residuals.
        double_block_residuals: list[torch.Tensor] = []
        xlabs_double_block_residuals = xlabs_output.controlnet_double_block_residuals
        if xlabs_double_block_residuals is not None:
            for i in range(self._flux_transformer_num_double_blocks):
                double_block_residuals.append(xlabs_double_block_residuals[i % len(xlabs_double_block_residuals)])

        return ControlNetFluxOutput(
            double_block_residuals=double_block_residuals,
            single_block_residuals=None,
        )
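In contrast to the InstantX `interval_control` mapping, the XLabs scheme above cycles through its short residual list with a modulo index, so consecutive double blocks alternate between the available residuals. A small standalone illustration (two residuals, matching the two `controlnet_blocks` in the XLabs canny checkpoint):

```python
num_transformer_blocks = 19  # FLUX dev double blocks
xlabs_residuals = ["r0", "r1"]  # XLabs ControlNets emit a short list of residuals

mapping = [xlabs_residuals[i % len(xlabs_residuals)] for i in range(num_transformer_blocks)]
print(mapping)  # ['r0', 'r1', 'r0', 'r1', ...] -- alternating across all 19 blocks
```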
    def run_controlnet(
        self,
        timestep_index: int,
        total_num_timesteps: int,
        img: torch.Tensor,
        img_ids: torch.Tensor,
        txt: torch.Tensor,
        txt_ids: torch.Tensor,
        y: torch.Tensor,
        timesteps: torch.Tensor,
        guidance: torch.Tensor | None,
    ) -> ControlNetFluxOutput:
        weight = self._get_weight(timestep_index=timestep_index, total_num_timesteps=total_num_timesteps)
        if weight < 1e-6:
            return ControlNetFluxOutput(single_block_residuals=None, double_block_residuals=None)

        xlabs_output: XLabsControlNetFluxOutput = self._model(
            img=img,
            img_ids=img_ids,
            controlnet_cond=self._controlnet_cond,
            txt=txt,
            txt_ids=txt_ids,
            timesteps=timesteps,
            y=y,
            guidance=guidance,
        )

        controlnet_output = self._xlabs_output_to_controlnet_output(xlabs_output)
        controlnet_output.apply_weight(weight)
        return controlnet_output
@@ -87,7 +87,9 @@ class Flux(nn.Module):
         txt_ids: Tensor,
         timesteps: Tensor,
         y: Tensor,
-        guidance: Tensor | None = None,
+        guidance: Tensor | None,
+        controlnet_double_block_residuals: list[Tensor] | None,
+        controlnet_single_block_residuals: list[Tensor] | None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -105,12 +107,27 @@ class Flux(nn.Module):
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)

-        for block in self.double_blocks:
+        # Validate double_block_residuals shape.
+        if controlnet_double_block_residuals is not None:
+            assert len(controlnet_double_block_residuals) == len(self.double_blocks)
+
+        for block_index, block in enumerate(self.double_blocks):
             img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

+            if controlnet_double_block_residuals is not None:
+                img += controlnet_double_block_residuals[block_index]
+
         img = torch.cat((txt, img), 1)
-        for block in self.single_blocks:
+
+        # Validate single_block_residuals shape.
+        if controlnet_single_block_residuals is not None:
+            assert len(controlnet_single_block_residuals) == len(self.single_blocks)
+
+        for block_index, block in enumerate(self.single_blocks):
             img = block(img, vec=vec, pe=pe)

+            if controlnet_single_block_residuals is not None:
+                img[:, txt.shape[1] :, ...] += controlnet_single_block_residuals[block_index]
+
         img = img[:, txt.shape[1] :, ...]

         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
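This forward-pass change is the whole integration point: each residual is added to the hidden state right after its corresponding block, and in the single-block phase only the image-token slice (everything after the first `txt.shape[1]` text tokens) is offset. A toy tensor demonstration of that slice arithmetic (dimensions are illustrative, not FLUX's real sizes):

```python
import torch

batch, txt_tokens, img_tokens, dim = 1, 4, 6, 8
# Concatenated sequence with text tokens first -- mirrors `img = torch.cat((txt, img), 1)`.
seq = torch.zeros(batch, txt_tokens + img_tokens, dim)
residual = torch.ones(batch, img_tokens, dim)

# Only the image-token positions receive the ControlNet residual.
seq[:, txt_tokens:, ...] += residual

assert seq[:, :txt_tokens].abs().sum() == 0  # text tokens untouched
assert torch.equal(seq[:, txt_tokens:], residual)  # image tokens offset
```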
@@ -8,17 +8,36 @@ from diffusers import ControlNetModel
 from invokeai.backend.model_manager import (
     AnyModel,
     AnyModelConfig,
 )
-from invokeai.backend.model_manager.config import ControlNetCheckpointConfig, SubModelType
+from invokeai.backend.model_manager.config import (
+    BaseModelType,
+    ControlNetCheckpointConfig,
+    ModelFormat,
+    ModelType,
+    SubModelType,
+)
 from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
 from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader


-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
-@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Diffusers
+)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusion1, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
+)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Diffusers
+)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusion2, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
+)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Diffusers
+)
+@ModelLoaderRegistry.register(
+    base=BaseModelType.StableDiffusionXL, type=ModelType.ControlNet, format=ModelFormat.Checkpoint
+)
 class ControlNetLoader(GenericDiffusersLoader):
     """Class to load ControlNet models."""
@@ -10,6 +10,15 @@ from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForTextEncoding, CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer

 from invokeai.app.services.config.config_default import get_config
+from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
+from invokeai.backend.flux.controlnet.state_dict_utils import (
+    convert_diffusers_instantx_state_dict_to_bfl_format,
+    infer_flux_params_from_state_dict,
+    infer_instantx_num_control_modes_from_state_dict,
+    is_state_dict_instantx_controlnet,
+    is_state_dict_xlabs_controlnet,
+)
+from invokeai.backend.flux.controlnet.xlabs_controlnet_flux import XLabsControlNetFlux
 from invokeai.backend.flux.model import Flux
 from invokeai.backend.flux.modules.autoencoder import AutoEncoder
 from invokeai.backend.flux.util import ae_params, params
@@ -24,6 +33,8 @@ from invokeai.backend.model_manager import (
 from invokeai.backend.model_manager.config import (
     CheckpointConfigBase,
     CLIPEmbedDiffusersConfig,
+    ControlNetCheckpointConfig,
+    ControlNetDiffusersConfig,
     MainBnbQuantized4bCheckpointConfig,
     MainCheckpointConfig,
     MainGGUFCheckpointConfig,
@@ -293,3 +304,51 @@ class FluxBnbQuantizednf4bCheckpointModel(ModelLoader):
         sd = convert_bundle_to_flux_transformer_checkpoint(sd)
         model.load_state_dict(sd, assign=True)
         return model
+
+
+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
+@ModelLoaderRegistry.register(base=BaseModelType.Flux, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
+class FluxControlnetModel(ModelLoader):
+    """Class to load FLUX ControlNet models."""
+
+    def _load_model(
+        self,
+        config: AnyModelConfig,
+        submodel_type: Optional[SubModelType] = None,
+    ) -> AnyModel:
+        if isinstance(config, ControlNetCheckpointConfig):
+            model_path = Path(config.path)
+        elif isinstance(config, ControlNetDiffusersConfig):
+            # If this is a diffusers directory, we simply ignore the config file and load from the weight file.
+            model_path = Path(config.path) / "diffusion_pytorch_model.safetensors"
+        else:
+            raise ValueError(f"Unexpected ControlNet model config type: {type(config)}")
+
+        sd = load_file(model_path)
+
+        # Detect the FLUX ControlNet model type from the state dict.
+        if is_state_dict_xlabs_controlnet(sd):
+            return self._load_xlabs_controlnet(sd)
+        elif is_state_dict_instantx_controlnet(sd):
+            return self._load_instantx_controlnet(sd)
+        else:
+            raise ValueError("Do not recognize the state dict as an XLabs or InstantX ControlNet model.")
+
+    def _load_xlabs_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
+        with accelerate.init_empty_weights():
+            # HACK(ryand): Is it safe to assume dev here?
+            model = XLabsControlNetFlux(params["flux-dev"])
+
+        model.load_state_dict(sd, assign=True)
+        return model
+
+    def _load_instantx_controlnet(self, sd: dict[str, torch.Tensor]) -> AnyModel:
+        sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
+        flux_params = infer_flux_params_from_state_dict(sd)
+        num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)
+
+        with accelerate.init_empty_weights():
+            model = InstantXControlNetFlux(flux_params, num_control_modes)
+
+        model.load_state_dict(sd, assign=True)
+        return model
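A useful property of this loader is that format detection is purely state-dict driven, so it can be exercised without real weights. A hypothetical snippet using the shape fixtures added later in this commit (meta-device tensors, nothing actually loaded):

```python
import torch

from invokeai.backend.flux.controlnet.state_dict_utils import (
    is_state_dict_instantx_controlnet,
    is_state_dict_xlabs_controlnet,
)
from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_sd_shapes

# Build a weightless state dict with the right keys and shapes.
with torch.device("meta"):
    sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

assert is_state_dict_instantx_controlnet(sd)
assert not is_state_dict_xlabs_controlnet(sd)
```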
@@ -10,6 +10,10 @@ from picklescan.scanner import scan_file_path

 import invokeai.backend.util.logging as logger
 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.flux.controlnet.state_dict_utils import (
+    is_state_dict_instantx_controlnet,
+    is_state_dict_xlabs_controlnet,
+)
 from invokeai.backend.lora.conversions.flux_diffusers_lora_conversion_utils import (
     is_state_dict_likely_in_flux_diffusers_format,
 )
@@ -116,6 +120,7 @@ class ModelProbe(object):
         "CLIPModel": ModelType.CLIPEmbed,
         "CLIPTextModel": ModelType.CLIPEmbed,
         "T5EncoderModel": ModelType.T5Encoder,
+        "FluxControlNetModel": ModelType.ControlNet,
     }

     @classmethod
@@ -255,7 +260,19 @@ class ModelProbe(object):
         # LoRA models, but as of the time of writing, we support Diffusers FLUX PEFT LoRA models.
         elif key.endswith(("to_k_lora.up.weight", "to_q_lora.down.weight", "lora_A.weight", "lora_B.weight")):
             return ModelType.LoRA
-        elif key.startswith(("controlnet", "control_model", "input_blocks")):
+        elif key.startswith(
+            (
+                "controlnet",
+                "control_model",
+                "input_blocks",
+                # XLabs FLUX ControlNet models have keys starting with "controlnet_blocks."
+                # For example: https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
+                # TODO(ryand): This is very fragile. XLabs FLUX ControlNet models also contain keys starting with
+                # "double_blocks.", which we check for above. But, I'm afraid to modify this logic because it is so
+                # delicate.
+                "controlnet_blocks",
+            )
+        ):
             return ModelType.ControlNet
         elif key.startswith(("image_proj.", "ip_adapter.")):
             return ModelType.IPAdapter
@@ -438,6 +455,7 @@ MODEL_NAME_TO_PREPROCESSOR = {
     "lineart": "lineart_image_processor",
     "lineart_anime": "lineart_anime_image_processor",
     "softedge": "hed_image_processor",
     "hed": "hed_image_processor",
     "shuffle": "content_shuffle_image_processor",
     "pose": "dw_openpose_image_processor",
     "mediapipe": "mediapipe_face_processor",
@@ -449,7 +467,8 @@ MODEL_NAME_TO_PREPROCESSOR = {

 def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
+    model_name_lower = model_name.lower()
     for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
-        if k in model_name:
+        if k in model_name_lower:
             return ControlAdapterDefaultSettings(preprocessor=v)
     return None
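The lowercasing fix matters because installed model names are arbitrary user-facing strings. A quick before/after illustration against one of the mappings above:

```python
MODEL_NAME_TO_PREPROCESSOR = {"pose": "dw_openpose_image_processor"}  # excerpt

model_name = "DWPose-SDXL"
print(any(k in model_name for k in MODEL_NAME_TO_PREPROCESSOR))          # False: "pose" != "Pose"
print(any(k in model_name.lower() for k in MODEL_NAME_TO_PREPROCESSOR))  # True: default settings found
```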
@@ -623,6 +642,11 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):

     def get_base_type(self) -> BaseModelType:
         checkpoint = self.checkpoint
+        if is_state_dict_xlabs_controlnet(checkpoint) or is_state_dict_instantx_controlnet(checkpoint):
+            # TODO(ryand): Should I distinguish between XLabs, InstantX and other ControlNet models by implementing
+            # get_format()?
+            return BaseModelType.Flux
+
         for key_name in (
             "control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
             "controlnet_mid_block.bias",
@@ -844,22 +868,19 @@ class ControlNetFolderProbe(FolderProbeBase):
             raise InvalidModelConfigException(f"Cannot determine base type for {self.model_path}")
         with open(config_file, "r") as file:
             config = json.load(file)
+
+        if config.get("_class_name", None) == "FluxControlNetModel":
+            return BaseModelType.Flux
+
         # no obvious way to distinguish between sd2-base and sd2-768
         dimension = config["cross_attention_dim"]
-        base_model = (
-            BaseModelType.StableDiffusion1
-            if dimension == 768
-            else (
-                BaseModelType.StableDiffusion2
-                if dimension == 1024
-                else BaseModelType.StableDiffusionXL
-                if dimension == 2048
-                else None
-            )
-        )
-        if not base_model:
-            raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")
-        return base_model
+        if dimension == 768:
+            return BaseModelType.StableDiffusion1
+        if dimension == 1024:
+            return BaseModelType.StableDiffusion2
+        if dimension == 2048:
+            return BaseModelType.StableDiffusionXL
+        raise InvalidModelConfigException(f"Unable to determine model base for {self.model_path}")


 class LoRAFolderProbe(FolderProbeBase):
@@ -422,6 +422,13 @@ STARTER_MODELS: list[StarterModel] = [
         description="ControlNet weights trained on sdxl-1.0 with tiled image conditioning",
         type=ModelType.ControlNet,
     ),
+    StarterModel(
+        name="FLUX.1-dev-Controlnet-Union-Pro",
+        base=BaseModelType.Flux,
+        source="Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro",
+        description="A unified ControlNet for FLUX.1-dev model that supports 7 control modes, including canny (0), tile (1), depth (2), blur (3), pose (4), gray (5), low quality (6)",
+        type=ModelType.ControlNet,
+    ),
     # endregion
     # region T2I Adapter
     StarterModel(
@@ -80,7 +80,6 @@ export const CanvasAddEntityButtons = memo(() => {
               justifyContent="flex-start"
               leftIcon={<PiPlusBold />}
               onClick={addControlLayer}
-              isDisabled={isFLUX}
             >
               {t('controlLayers.controlLayer')}
             </Button>
@@ -56,7 +56,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
         </MenuItem>
       </MenuGroup>
       <MenuGroup title={t('controlLayers.layer_other')}>
-        <MenuItem icon={<PiPlusBold />} onClick={addControlLayer} isDisabled={isFLUX}>
+        <MenuItem icon={<PiPlusBold />} onClick={addControlLayer}>
          {t('controlLayers.controlLayer')}
        </MenuItem>
        <MenuItem icon={<PiPlusBold />} onClick={addRasterLayer}>
@@ -16,6 +16,7 @@ import {
   controlLayerModelChanged,
   controlLayerWeightChanged,
 } from 'features/controlLayers/store/canvasSlice';
+import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
 import { selectCanvasSlice, selectEntityOrThrow } from 'features/controlLayers/store/selectors';
 import type { CanvasEntityIdentifier, ControlModeV2 } from 'features/controlLayers/store/types';
 import { memo, useCallback, useMemo } from 'react';
@@ -42,6 +43,7 @@ export const ControlLayerControlAdapter = memo(() => {
   const entityIdentifier = useEntityIdentifierContext('control_layer');
   const controlAdapter = useControlLayerControlAdapter(entityIdentifier);
   const filter = useEntityFilter(entityIdentifier);
+  const isFLUX = useAppSelector(selectIsFLUX);

   const onChangeBeginEndStepPct = useCallback(
     (beginEndStepPct: [number, number]) => {
@@ -117,7 +119,7 @@ export const ControlLayerControlAdapter = memo(() => {
       </Flex>
       <Weight weight={controlAdapter.weight} onChange={onChangeWeight} />
       <BeginEndStepPct beginEndStepPct={controlAdapter.beginEndStepPct} onChange={onChangeBeginEndStepPct} />
-      {controlAdapter.type === 'controlnet' && (
+      {controlAdapter.type === 'controlnet' && !isFLUX && (
         <ControlLayerControlAdapterControlMode
           controlMode={controlAdapter.controlMode}
           onChange={onChangeControlMode}
@@ -110,10 +110,10 @@ const addControlNetToGraph = (

   const controlNet = g.addNode({
     id: `control_net_${id}`,
-    type: 'controlnet',
+    type: model.base === 'flux' ? 'flux_controlnet' : 'controlnet',
     begin_step_percent: beginEndStepPct[0],
     end_step_percent: beginEndStepPct[1],
-    control_mode: controlMode,
+    control_mode: model.base === 'flux' ? undefined : controlMode,
     resize_mode: 'just_resize',
     control_model: model,
     control_weight: weight,
@@ -19,6 +19,8 @@ import type { Invocation } from 'services/api/types';
 import { isNonRefinerMainModelConfig } from 'services/api/types';
 import { assert } from 'tsafe';

+import { addControlNets } from './addControlAdapters';
+
 const log = logger('system');

 export const buildFLUXGraph = async (
@@ -93,6 +95,7 @@ export const buildFLUXGraph = async (
   > = l2i;

   g.addEdge(modelLoader, 'transformer', noise, 'transformer');
+  g.addEdge(modelLoader, 'vae', noise, 'controlnet_vae');
   g.addEdge(modelLoader, 'vae', l2i, 'vae');

   g.addEdge(modelLoader, 'clip', posCond, 'clip');
@@ -177,6 +180,24 @@ export const buildFLUXGraph = async (
     );
   }

+  const controlNetCollector = g.addNode({
+    type: 'collect',
+    id: getPrefixedId('control_net_collector'),
+  });
+  const controlNetResult = await addControlNets(
+    manager,
+    canvas.controlLayers.entities,
+    g,
+    canvas.bbox.rect,
+    controlNetCollector,
+    modelConfig.base
+  );
+  if (controlNetResult.addedControlNets > 0) {
+    g.addEdge(controlNetCollector, 'collection', noise, 'control');
+  } else {
+    g.deleteNode(controlNetCollector.id);
+  }
+
   if (state.system.shouldUseNSFWChecker) {
     canvasOutput = addNSFWChecker(g, canvasOutput);
   }
File diff suppressed because one or more lines are too long
30  scripts/extract_sd_keys_and_shapes.py  Normal file
@@ -0,0 +1,30 @@
import argparse
import json

from safetensors.torch import load_file


def extract_sd_keys_and_shapes(safetensors_file: str):
    sd = load_file(safetensors_file)

    keys_to_shapes = {k: v.shape for k, v in sd.items()}

    out_file = "keys_and_shapes.json"
    with open(out_file, "w") as f:
        json.dump(keys_to_shapes, f, indent=4)

    print(f"Keys and shapes written to '{out_file}'.")


def main():
    parser = argparse.ArgumentParser(
        description="Extracts the keys and shapes from the state dict in a safetensors file. Intended for creating "
        + "dummy state dicts for use in unit tests."
    )
    parser.add_argument("safetensors_file", type=str, help="Path to the safetensors file.")
    args = parser.parse_args()
    extract_sd_keys_and_shapes(args.safetensors_file)


if __name__ == "__main__":
    main()
374  tests/backend/flux/controlnet/instantx_flux_controlnet_state_dict.py  Normal file
@@ -0,0 +1,374 @@
# State dict keys and shapes for an InstantX FLUX ControlNet Union model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/4f32d6f2b220f8873d49bb8acc073e1df180c994/diffusion_pytorch_model.safetensors
# Note: the committed file writes every per-block entry out explicitly. Since the entries repeat with identical
# shapes across block indices, they are generated in loops here; the resulting dict is the same.
instantx_sd_shapes = {
    "context_embedder.bias": [3072],
    "context_embedder.weight": [3072, 4096],
    "controlnet_mode_embedder.weight": [10, 3072],
    "controlnet_x_embedder.bias": [3072],
    "controlnet_x_embedder.weight": [3072, 64],
    "x_embedder.bias": [3072],
    "x_embedder.weight": [3072, 64],
}

# Timestep/text/guidance embedders: only the first linear layer's input dim differs.
for _emb, _in_dim in [("guidance_embedder", 256), ("text_embedder", 768), ("timestep_embedder", 256)]:
    instantx_sd_shapes.update(
        {
            f"time_text_embed.{_emb}.linear_1.bias": [3072],
            f"time_text_embed.{_emb}.linear_1.weight": [3072, _in_dim],
            f"time_text_embed.{_emb}.linear_2.bias": [3072],
            f"time_text_embed.{_emb}.linear_2.weight": [3072, 3072],
        }
    )

# 5 ControlNet double-block output projections and 10 single-block output projections.
for _i in range(5):
    instantx_sd_shapes[f"controlnet_blocks.{_i}.bias"] = [3072]
    instantx_sd_shapes[f"controlnet_blocks.{_i}.weight"] = [3072, 3072]
for _i in range(10):
    instantx_sd_shapes[f"controlnet_single_blocks.{_i}.bias"] = [3072]
    instantx_sd_shapes[f"controlnet_single_blocks.{_i}.weight"] = [3072, 3072]

# 10 single transformer blocks, identical shapes per block.
for _i in range(10):
    instantx_sd_shapes.update(
        {
            f"single_transformer_blocks.{_i}.attn.norm_k.weight": [128],
            f"single_transformer_blocks.{_i}.attn.norm_q.weight": [128],
            f"single_transformer_blocks.{_i}.attn.to_k.bias": [3072],
            f"single_transformer_blocks.{_i}.attn.to_k.weight": [3072, 3072],
            f"single_transformer_blocks.{_i}.attn.to_q.bias": [3072],
            f"single_transformer_blocks.{_i}.attn.to_q.weight": [3072, 3072],
            f"single_transformer_blocks.{_i}.attn.to_v.bias": [3072],
            f"single_transformer_blocks.{_i}.attn.to_v.weight": [3072, 3072],
            f"single_transformer_blocks.{_i}.norm.linear.bias": [9216],
            f"single_transformer_blocks.{_i}.norm.linear.weight": [9216, 3072],
            f"single_transformer_blocks.{_i}.proj_mlp.bias": [12288],
            f"single_transformer_blocks.{_i}.proj_mlp.weight": [12288, 3072],
            f"single_transformer_blocks.{_i}.proj_out.bias": [3072],
            f"single_transformer_blocks.{_i}.proj_out.weight": [3072, 15360],
        }
    )

# 5 double (joint) transformer blocks, identical shapes per block.
for _i in range(5):
    instantx_sd_shapes.update(
        {
            f"transformer_blocks.{_i}.attn.add_k_proj.bias": [3072],
            f"transformer_blocks.{_i}.attn.add_k_proj.weight": [3072, 3072],
            f"transformer_blocks.{_i}.attn.add_q_proj.bias": [3072],
            f"transformer_blocks.{_i}.attn.add_q_proj.weight": [3072, 3072],
            f"transformer_blocks.{_i}.attn.add_v_proj.bias": [3072],
            f"transformer_blocks.{_i}.attn.add_v_proj.weight": [3072, 3072],
            f"transformer_blocks.{_i}.attn.norm_added_k.weight": [128],
            f"transformer_blocks.{_i}.attn.norm_added_q.weight": [128],
            f"transformer_blocks.{_i}.attn.norm_k.weight": [128],
            f"transformer_blocks.{_i}.attn.norm_q.weight": [128],
            f"transformer_blocks.{_i}.attn.to_add_out.bias": [3072],
            f"transformer_blocks.{_i}.attn.to_add_out.weight": [3072, 3072],
            f"transformer_blocks.{_i}.attn.to_k.bias": [3072],
            f"transformer_blocks.{_i}.attn.to_k.weight": [3072, 3072],
            f"transformer_blocks.{_i}.attn.to_out.0.bias": [3072],
            f"transformer_blocks.{_i}.attn.to_out.0.weight": [3072, 3072],
            f"transformer_blocks.{_i}.attn.to_q.bias": [3072],
            f"transformer_blocks.{_i}.attn.to_q.weight": [3072, 3072],
            f"transformer_blocks.{_i}.attn.to_v.bias": [3072],
            f"transformer_blocks.{_i}.attn.to_v.weight": [3072, 3072],
            f"transformer_blocks.{_i}.ff.net.0.proj.bias": [12288],
            f"transformer_blocks.{_i}.ff.net.0.proj.weight": [12288, 3072],
            f"transformer_blocks.{_i}.ff.net.2.bias": [3072],
            f"transformer_blocks.{_i}.ff.net.2.weight": [3072, 12288],
            f"transformer_blocks.{_i}.ff_context.net.0.proj.bias": [12288],
            f"transformer_blocks.{_i}.ff_context.net.0.proj.weight": [12288, 3072],
            f"transformer_blocks.{_i}.ff_context.net.2.bias": [3072],
            f"transformer_blocks.{_i}.ff_context.net.2.weight": [3072, 12288],
            f"transformer_blocks.{_i}.norm1.linear.bias": [18432],
            f"transformer_blocks.{_i}.norm1.linear.weight": [18432, 3072],
            f"transformer_blocks.{_i}.norm1_context.linear.bias": [18432],
            f"transformer_blocks.{_i}.norm1_context.linear.weight": [18432, 3072],
        }
    )


# InstantX FLUX ControlNet config for unit tests.
# Copied from https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Union/blob/main/config.json
instantx_config = {
    "_class_name": "FluxControlNetModel",
    "_diffusers_version": "0.30.0.dev0",
    "_name_or_path": "/mnt/wangqixun/",
    "attention_head_dim": 128,
    "axes_dims_rope": [16, 56, 56],
    "guidance_embeds": True,
    "in_channels": 64,
    "joint_attention_dim": 4096,
    "num_attention_heads": 24,
    "num_layers": 5,
    "num_mode": 10,
    "num_single_layers": 10,
    "patch_size": 1,
    "pooled_projection_dim": 768,
}
108  tests/backend/flux/controlnet/test_state_dict_utils.py  Normal file
@@ -0,0 +1,108 @@
import sys

import pytest
import torch

from invokeai.backend.flux.controlnet.instantx_controlnet_flux import InstantXControlNetFlux
from invokeai.backend.flux.controlnet.state_dict_utils import (
    convert_diffusers_instantx_state_dict_to_bfl_format,
    infer_flux_params_from_state_dict,
    infer_instantx_num_control_modes_from_state_dict,
    is_state_dict_instantx_controlnet,
    is_state_dict_xlabs_controlnet,
)
from tests.backend.flux.controlnet.instantx_flux_controlnet_state_dict import instantx_config, instantx_sd_shapes
from tests.backend.flux.controlnet.xlabs_flux_controlnet_state_dict import xlabs_sd_shapes


@pytest.mark.parametrize(
    ["sd_shapes", "expected"],
    [
        (xlabs_sd_shapes, True),
        (instantx_sd_shapes, False),
        (["foo"], False),
    ],
)
def test_is_state_dict_xlabs_controlnet(sd_shapes: dict[str, list[int]], expected: bool):
    sd = {k: None for k in sd_shapes}
    assert is_state_dict_xlabs_controlnet(sd) == expected


@pytest.mark.parametrize(
    ["sd_keys", "expected"],
    [
        (instantx_sd_shapes, True),
        (xlabs_sd_shapes, False),
        (["foo"], False),
    ],
)
def test_is_state_dict_instantx_controlnet(sd_keys: list[str], expected: bool):
    sd = {k: None for k in sd_keys}
    assert is_state_dict_instantx_controlnet(sd) == expected


def test_convert_diffusers_instantx_state_dict_to_bfl_format():
    """Smoke test convert_diffusers_instantx_state_dict_to_bfl_format() to ensure that it handles all of the keys."""
    sd = {k: torch.zeros(1) for k in instantx_sd_shapes}
    bfl_sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    assert bfl_sd is not None


# TODO(ryand): Figure out why some tests in this file are failing on the MacOS CI runners. It seems to be related to
# using the meta device. I can't reproduce the issue on my local MacOS system.


@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_flux_params_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    flux_params = infer_flux_params_from_state_dict(sd)

    assert flux_params.in_channels == instantx_config["in_channels"]
    assert flux_params.vec_in_dim == instantx_config["pooled_projection_dim"]
    assert flux_params.context_in_dim == instantx_config["joint_attention_dim"]
    assert flux_params.hidden_size // flux_params.num_heads == instantx_config["attention_head_dim"]
    assert flux_params.num_heads == instantx_config["num_attention_heads"]
    assert flux_params.mlp_ratio == 4
    assert flux_params.depth == instantx_config["num_layers"]
    assert flux_params.depth_single_blocks == instantx_config["num_single_layers"]
    assert flux_params.axes_dim == instantx_config["axes_dims_rope"]
    assert flux_params.theta == 10000
    assert flux_params.qkv_bias
    assert flux_params.guidance_embed == instantx_config["guidance_embeds"]


@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_infer_instantx_num_control_modes_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)

    assert num_control_modes == instantx_config["num_mode"]


@pytest.mark.skipif(sys.platform == "darwin", reason="Skipping on macOS")
def test_load_instantx_from_state_dict():
    # Construct a dummy state_dict with tensors of the correct shape on the meta device.
    with torch.device("meta"):
        sd = {k: torch.zeros(v) for k, v in instantx_sd_shapes.items()}

    sd = convert_diffusers_instantx_state_dict_to_bfl_format(sd)
    flux_params = infer_flux_params_from_state_dict(sd)
    num_control_modes = infer_instantx_num_control_modes_from_state_dict(sd)

    with torch.device("meta"):
        model = InstantXControlNetFlux(flux_params, num_control_modes)

    model_sd = model.state_dict()

    assert set(model_sd.keys()) == set(sd.keys())
    for key, tensor in model_sd.items():
        assert isinstance(tensor, torch.Tensor)
        assert tensor.shape == sd[key].shape
91  tests/backend/flux/controlnet/xlabs_flux_controlnet_state_dict.py  Normal file
@@ -0,0 +1,91 @@
# State dict keys and shapes for an XLabs FLUX ControlNet model. Intended to be used for unit tests.
# These keys were extracted from:
# https://huggingface.co/XLabs-AI/flux-controlnet-collections/blob/86ab1e915a389d5857135c00e0d350e9e38a9048/flux-canny-controlnet_v2.safetensors
# Note: the committed file writes every per-block entry out explicitly. Since the entries repeat with identical
# shapes across block indices, they are generated in loops here; the resulting dict is the same.
xlabs_sd_shapes = {
    "guidance_in.in_layer.bias": [3072],
    "guidance_in.in_layer.weight": [3072, 256],
    "guidance_in.out_layer.bias": [3072],
    "guidance_in.out_layer.weight": [3072, 3072],
    "img_in.bias": [3072],
    "img_in.weight": [3072, 64],
    "pos_embed_input.bias": [3072],
    "pos_embed_input.weight": [3072, 64],
    "time_in.in_layer.bias": [3072],
    "time_in.in_layer.weight": [3072, 256],
    "time_in.out_layer.bias": [3072],
    "time_in.out_layer.weight": [3072, 3072],
    "txt_in.bias": [3072],
    "txt_in.weight": [3072, 4096],
    "vector_in.in_layer.bias": [3072],
    "vector_in.in_layer.weight": [3072, 768],
    "vector_in.out_layer.bias": [3072],
    "vector_in.out_layer.weight": [3072, 3072],
}

# 2 ControlNet output projections and 2 double blocks; the img and txt streams of each double block have
# identical shapes.
for _i in range(2):
    xlabs_sd_shapes[f"controlnet_blocks.{_i}.bias"] = [3072]
    xlabs_sd_shapes[f"controlnet_blocks.{_i}.weight"] = [3072, 3072]
    for _stream in ("img", "txt"):
        xlabs_sd_shapes.update(
            {
                f"double_blocks.{_i}.{_stream}_attn.norm.key_norm.scale": [128],
                f"double_blocks.{_i}.{_stream}_attn.norm.query_norm.scale": [128],
                f"double_blocks.{_i}.{_stream}_attn.proj.bias": [3072],
                f"double_blocks.{_i}.{_stream}_attn.proj.weight": [3072, 3072],
                f"double_blocks.{_i}.{_stream}_attn.qkv.bias": [9216],
                f"double_blocks.{_i}.{_stream}_attn.qkv.weight": [9216, 3072],
                f"double_blocks.{_i}.{_stream}_mlp.0.bias": [12288],
                f"double_blocks.{_i}.{_stream}_mlp.0.weight": [12288, 3072],
                f"double_blocks.{_i}.{_stream}_mlp.2.bias": [3072],
                f"double_blocks.{_i}.{_stream}_mlp.2.weight": [3072, 12288],
                f"double_blocks.{_i}.{_stream}_mod.lin.bias": [18432],
                f"double_blocks.{_i}.{_stream}_mod.lin.weight": [18432, 3072],
            }
        )

# Input hint block: layer 0 takes the 3-channel control image; the remaining even-indexed layers are 16->16 convs.
xlabs_sd_shapes["input_hint_block.0.bias"] = [16]
xlabs_sd_shapes["input_hint_block.0.weight"] = [16, 3, 3, 3]
for _i in range(2, 16, 2):
    xlabs_sd_shapes[f"input_hint_block.{_i}.bias"] = [16]
    xlabs_sd_shapes[f"input_hint_block.{_i}.weight"] = [16, 16, 3, 3]