diff --git a/invokeai/backend/flux/controlnet/state_dict_utils.py b/invokeai/backend/flux/controlnet/state_dict_utils.py
index 2575b4c56b..1fc87fb4be 100644
--- a/invokeai/backend/flux/controlnet/state_dict_utils.py
+++ b/invokeai/backend/flux/controlnet/state_dict_utils.py
@@ -43,87 +43,129 @@ def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
     return False
 
 
-def _fuse_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
+    """Fuse weights along dimension 0.
+
+    Used to fuse q, k, v attention weights (and, for single blocks, the MLP input projection) into a single tensor when converting from diffusers to BFL format.
+    """
     # TODO(ryand): Double check dim=0 is correct.
-    return torch.cat((q, k, v), dim=0)
+    return torch.cat(t, dim=0)
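The `dim=0` choice flagged in the TODO can be checked directly: for an `nn.Linear`, dim 0 of the weight indexes output features, so concatenating q, k, and v along dim 0 yields a fused projection whose output is the individual outputs concatenated. A minimal sanity check (the shapes are illustrative, not FLUX's real dimensions):

```python
import torch

q, k, v = (torch.nn.Linear(8, 8) for _ in range(3))

# Rows of a Linear weight are output features, so stacking along dim=0
# builds one projection that emits [q(x) | k(x) | v(x)].
qkv = torch.nn.Linear(8, 24)
with torch.no_grad():
    qkv.weight.copy_(torch.cat((q.weight, k.weight, v.weight), dim=0))
    qkv.bias.copy_(torch.cat((q.bias, k.bias, v.bias), dim=0))

x = torch.randn(2, 8)
assert torch.allclose(qkv(x), torch.cat((q(x), k(x), v(x)), dim=-1), atol=1e-6)
```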
f"double_blocks.{double_block_index}." + from_prefix = f"transformer_blocks.{double_block_index}." new_sd: dict[str, torch.Tensor] = {} - new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.bias"] = _fuse_qkv( - sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.bias"), - sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.bias"), - sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.bias"), + # Check one key to determine if this block exists. + if f"{from_prefix}.attn.add_q_proj.bias" not in sd: + return new_sd + + # txt_attn.qkv + new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights( + sd.pop(f"{from_prefix}.attn.add_q_proj.bias"), + sd.pop(f"{from_prefix}.attn.add_k_proj.bias"), + sd.pop(f"{from_prefix}.attn.add_v_proj.bias"), ) - new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.weight"] = _fuse_qkv( - sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.weight"), - sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.weight"), - sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.weight"), + new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights( + sd.pop(f"{from_prefix}.attn.add_q_proj.weight"), + sd.pop(f"{from_prefix}.attn.add_k_proj.weight"), + sd.pop(f"{from_prefix}.attn.add_v_proj.weight"), ) + # img_attn.qkv + new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights( + sd.pop(f"{from_prefix}.attn.to_q.bias"), + sd.pop(f"{from_prefix}.attn.to_k.bias"), + sd.pop(f"{from_prefix}.attn.to_v.bias"), + ) + new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights( + sd.pop(f"{from_prefix}.attn.to_q.weight"), + sd.pop(f"{from_prefix}.attn.to_k.weight"), + sd.pop(f"{from_prefix}.attn.to_v.weight"), + ) + + # Handle basic 1-to-1 key conversions. + key_map = { + # img_attn + "attn.norm_k.weight": "img_attn.norm.key_norm.scale", + "attn.norm_q.weight": "img_attn.norm.query_norm.scale", + "attn.to_out.0.weight": "img_attn.proj.weight", + "attn.to_out.0.bias": "img_attn.proj.bias", + # img_mlp + "ff.net.0.proj.weight": "img_mlp.0.weight", + "ff.net.0.proj.bias": "img_mlp.0.bias", + "ff.net.2.weight": "img_mlp.2.weight", + "ff.net.2.bias": "img_mlp.2.bias", + # img_mod + "norm1.linear.weight": "img_mod.lin.weight", + "norm1.linear.bias": "img_mod.lin.bias", + # txt_attn + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + "attn.to_add_out.weight": "txt_attn.proj.weight", + "attn.to_add_out.bias": "txt_attn.proj.bias", + # txt_mlp + "ff_context.net.0.proj.weight": "txt_mlp.0.weight", + "ff_context.net.0.proj.bias": "txt_mlp.0.bias", + "ff_context.net.2.weight": "txt_mlp.2.weight", + "ff_context.net.2.bias": "txt_mlp.2.bias", + # txt_mod + "norm1_context.linear.weight": "txt_mod.lin.weight", + "norm1_context.linear.bias": "txt_mod.lin.bias", + } + for from_key, to_key in key_map.items(): + new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}") + + return new_sd + + +def _convert_flux_single_block_sd_from_diffusers_to_bfl_format( + sd: Dict[str, torch.Tensor], single_block_index: int +) -> Dict[str, torch.Tensor]: + """Convert the state dict for a single block from diffusers format to BFL format.""" + to_prefix = f"single_blocks.{single_block_index}." + from_prefix = f"single_transformer_blocks.{single_block_index}." + + new_sd: dict[str, torch.Tensor] = {} + + # Check one key to determine if this block exists. 
+ if f"{from_prefix}.attn.to_q.bias" not in sd: + return new_sd + + # linear1 (qkv) + new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights( + sd.pop(f"{from_prefix}.attn.to_q.bias"), + sd.pop(f"{from_prefix}.attn.to_k.bias"), + sd.pop(f"{from_prefix}.attn.to_v.bias"), + sd.pop(f"{from_prefix}.proj_mlp.bias"), + ) + new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights( + sd.pop(f"{from_prefix}.attn.to_q.weight"), + sd.pop(f"{from_prefix}.attn.to_k.weight"), + sd.pop(f"{from_prefix}.attn.to_v.weight"), + sd.pop(f"{from_prefix}.proj_mlp.weight"), + ) + + # Handle basic 1-to-1 key conversions. + key_map = { + # linear2 + "proj_out.weight": "linear2.weight", + "proj_out.bias": "linear2.bias", + # modulation + "norm.linear.weight": "modulation.lin.weight", + "norm.linear.bias": "modulation.lin.bias", + # norm + "attn.norm_k.weight": "norm.key_norm.scale", + "attn.norm_q.weight": "norm.query_norm.scale", + } + for from_key, to_key in key_map.items(): + new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}") + return new_sd @@ -181,7 +223,13 @@ def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tens block_index += 1 # Handle the single_blocks. - ... + block_index = 0 + while True: + converted_singe_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index) + if len(converted_singe_block_sd) == 0: + break + new_sd.update(converted_singe_block_sd) + block_index += 1 # Transfer controlnet keys as-is. for k in sd: