diff --git a/invokeai/backend/flux/controlnet/state_dict_utils.py b/invokeai/backend/flux/controlnet/state_dict_utils.py index 98ec3d1250..daf79ef327 100644 --- a/invokeai/backend/flux/controlnet/state_dict_utils.py +++ b/invokeai/backend/flux/controlnet/state_dict_utils.py @@ -2,6 +2,8 @@ from typing import Any, Dict import torch +from invokeai.backend.flux.model import FluxParams + def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool: """Is the state dict for an XLabs ControlNet model? @@ -239,3 +241,57 @@ def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tens # Assert that all keys have been handled. assert len(sd) == 0 return new_sd + + +def infer_flux_params_from_state_dict(sd: Dict[str, torch.Tensor]) -> FluxParams: + """Infer the FluxParams from the shape of a FLUX state dict. When a model is distributed in diffusers format, this + information is all contained in the config.json file that accompanies the model. However, being apple to infer the + params from the state dict enables us to load models (e.g. an InstantX ControlNet) from a single weight file. + """ + hidden_size = sd["img_in.weight"].shape[0] + mlp_hidden_dim = sd["double_blocks.0.img_mlp.0.weight"].shape[0] + # mlp_ratio is a float, but we treat it as an int here to avoid having to think about possible flost precision + # issues. In practice, mlp_ratio is usually 4. + mlp_ratio = mlp_hidden_dim // hidden_size + + head_dim = sd["double_blocks.0.img_attn.norm.query_norm.scale"].shape[0] + num_heads = hidden_size // head_dim + + # Count the number of double blocks. + double_block_index = 0 + while f"double_blocks.{double_block_index}.img_attn.qkv.weight" in sd: + double_block_index += 1 + + # Count the number of single blocks. + single_block_index = 0 + while f"single_blocks.{single_block_index}.linear1.weight" in sd: + single_block_index += 1 + + return FluxParams( + in_channels=sd["img_in.weight"].shape[1], + vec_in_dim=sd["vector_in.in_layer.weight"].shape[1], + context_in_dim=sd["txt_in.weight"].shape[1], + hidden_size=hidden_size, + mlp_ratio=mlp_ratio, + num_heads=num_heads, + depth=double_block_index, + depth_single_blocks=single_block_index, + # axes_dim cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models. + axes_dim=[16, 56, 56], + # theta cannot be inferred from the state dict. The hard-coded value is correct for dev/schnell models. + theta=10_000, + qkv_bias="double_blocks.0.img_attn.qkv.bias" in sd, + guidance_embed="guidance_in.in_layer.weight" in sd, + ) + + +def infer_instantx_num_control_modes_from_state_dict(sd: Dict[str, torch.Tensor]) -> int | None: + """Infer the number of ControlNet Union modes from the shape of a InstantX ControlNet state dict. + + Returns None if the model is not a ControlNet Union model. Otherwise returns the number of modes. + """ + mode_embedder_key = "controlnet_mode_embedder.weight" + if mode_embedder_key not in sd: + return None + + return sd[mode_embedder_key].shape[0]