mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
WIP - implement convert_diffusers_instantx_state_dict_to_bfl_format(...).
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def is_state_dict_xlabs_controlnet(sd: Dict[str, Any]) -> bool:
|
||||
"""Is the state dict for an XLabs ControlNet model?
|
||||
@@ -39,3 +41,153 @@ def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
|
||||
if expected_keys.issubset(sd.keys()):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _fuse_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
# TODO(ryand): Double check dim=0 is correct.
|
||||
return torch.cat((q, k, v), dim=0)
|
||||
|
||||
|
||||
def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
|
||||
sd: Dict[str, torch.Tensor], double_block_index: int
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the state dict for a double block from diffusers format to BFL format."""
|
||||
|
||||
# double_blocks.0.img_attn.norm.key_norm.scale
|
||||
# double_blocks.0.img_attn.norm.query_norm.scale
|
||||
# double_blocks.0.img_attn.proj.bias
|
||||
# double_blocks.0.img_attn.proj.weight
|
||||
# double_blocks.0.img_attn.qkv.bias
|
||||
# double_blocks.0.img_attn.qkv.weight
|
||||
# double_blocks.0.img_mlp.0.bias
|
||||
# double_blocks.0.img_mlp.0.weight
|
||||
# double_blocks.0.img_mlp.2.bias
|
||||
# double_blocks.0.img_mlp.2.weight
|
||||
# double_blocks.0.img_mod.lin.bias
|
||||
# double_blocks.0.img_mod.lin.weight
|
||||
# double_blocks.0.txt_attn.norm.key_norm.scale
|
||||
# double_blocks.0.txt_attn.norm.query_norm.scale
|
||||
# double_blocks.0.txt_attn.proj.bias
|
||||
# double_blocks.0.txt_attn.proj.weight
|
||||
# double_blocks.0.txt_attn.qkv.bias
|
||||
# double_blocks.0.txt_attn.qkv.weight
|
||||
# double_blocks.0.txt_mlp.0.bias
|
||||
# double_blocks.0.txt_mlp.0.weight
|
||||
# double_blocks.0.txt_mlp.2.bias
|
||||
# double_blocks.0.txt_mlp.2.weight
|
||||
# double_blocks.0.txt_mod.lin.bias
|
||||
# double_blocks.0.txt_mod.lin.weight
|
||||
|
||||
# "transformer_blocks.0.attn.add_k_proj.bias",
|
||||
# "transformer_blocks.0.attn.add_k_proj.weight",
|
||||
# "transformer_blocks.0.attn.add_q_proj.bias",
|
||||
# "transformer_blocks.0.attn.add_q_proj.weight",
|
||||
# "transformer_blocks.0.attn.add_v_proj.bias",
|
||||
# "transformer_blocks.0.attn.add_v_proj.weight",
|
||||
# "transformer_blocks.0.attn.norm_added_k.weight",
|
||||
# "transformer_blocks.0.attn.norm_added_q.weight",
|
||||
# "transformer_blocks.0.attn.norm_k.weight",
|
||||
# "transformer_blocks.0.attn.norm_q.weight",
|
||||
# "transformer_blocks.0.attn.to_add_out.bias",
|
||||
# "transformer_blocks.0.attn.to_add_out.weight",
|
||||
# "transformer_blocks.0.attn.to_k.bias",
|
||||
# "transformer_blocks.0.attn.to_k.weight",
|
||||
# "transformer_blocks.0.attn.to_out.0.bias",
|
||||
# "transformer_blocks.0.attn.to_out.0.weight",
|
||||
# "transformer_blocks.0.attn.to_q.bias",
|
||||
# "transformer_blocks.0.attn.to_q.weight",
|
||||
# "transformer_blocks.0.attn.to_v.bias",
|
||||
# "transformer_blocks.0.attn.to_v.weight",
|
||||
# "transformer_blocks.0.ff.net.0.proj.bias",
|
||||
# "transformer_blocks.0.ff.net.0.proj.weight",
|
||||
# "transformer_blocks.0.ff.net.2.bias",
|
||||
# "transformer_blocks.0.ff.net.2.weight",
|
||||
# "transformer_blocks.0.ff_context.net.0.proj.bias",
|
||||
# "transformer_blocks.0.ff_context.net.0.proj.weight",
|
||||
# "transformer_blocks.0.ff_context.net.2.bias",
|
||||
# "transformer_blocks.0.ff_context.net.2.weight",
|
||||
# "transformer_blocks.0.norm1.linear.bias",
|
||||
# "transformer_blocks.0.norm1.linear.weight",
|
||||
# "transformer_blocks.0.norm1_context.linear.bias",
|
||||
# "transformer_blocks.0.norm1_context.linear.weight",
|
||||
|
||||
new_sd: dict[str, torch.Tensor] = {}
|
||||
|
||||
new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.bias"] = _fuse_qkv(
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.bias"),
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.bias"),
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.bias"),
|
||||
)
|
||||
new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.weight"] = _fuse_qkv(
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.weight"),
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.weight"),
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.weight"),
|
||||
)
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
||||
"""Convert an InstantX ControlNet state dict to the format that can be loaded by our internal
|
||||
DiffusersControlNetFlux.
|
||||
|
||||
The original InstantX ControlNet model was developed to be used in diffusers. We have ported the original
|
||||
implementation to DiffusersControlNetFlux to make it compatible with BFL-style models. This function converts the
|
||||
original state dict to the format expected by DiffusersControlNetFlux.
|
||||
"""
|
||||
# Shallow copy sd so that we can pop keys from it without modifying the original.
|
||||
sd = sd.copy()
|
||||
|
||||
new_sd: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Handle basic 1-to-1 key conversions.
|
||||
basic_key_map = {
|
||||
# Base model keys.
|
||||
# ----------------
|
||||
# txt_in keys.
|
||||
"context_embedder.bias": "txt_in.bias",
|
||||
"context_embedder.weight": "txt_in.weight",
|
||||
# guidance_in MLPEmbedder keys.
|
||||
"time_text_embed.guidance_embedder.linear_1.bias": "guidance_in.in_layer.bias",
|
||||
"time_text_embed.guidance_embedder.linear_1.weight": "guidance_in.in_layer.weight",
|
||||
"time_text_embed.guidance_embedder.linear_2.bias": "guidance_in.out_layer.bias",
|
||||
"time_text_embed.guidance_embedder.linear_2.weight": "guidance_in.out_layer.weight",
|
||||
# vector_in MLPEmbedder keys.
|
||||
"time_text_embed.text_embedder.linear_1.bias": "vector_in.in_layer.bias",
|
||||
"time_text_embed.text_embedder.linear_1.weight": "vector_in.in_layer.weight",
|
||||
"time_text_embed.text_embedder.linear_2.bias": "vector_in.out_layer.bias",
|
||||
"time_text_embed.text_embedder.linear_2.weight": "vector_in.out_layer.weight",
|
||||
# time_in MLPEmbedder keys.
|
||||
"time_text_embed.timestep_embedder.linear_1.bias": "time_in.in_layer.bias",
|
||||
"time_text_embed.timestep_embedder.linear_1.weight": "time_in.in_layer.weight",
|
||||
"time_text_embed.timestep_embedder.linear_2.bias": "time_in.out_layer.bias",
|
||||
"time_text_embed.timestep_embedder.linear_2.weight": "time_in.out_layer.weight",
|
||||
# img_in keys.
|
||||
"x_embedder.bias": "img_in.bias",
|
||||
"x_embedder.weight": "img_in.weight",
|
||||
}
|
||||
for old_key, new_key in basic_key_map.items():
|
||||
v = sd.pop(old_key, None)
|
||||
if v is not None:
|
||||
new_sd[new_key] = v
|
||||
|
||||
# Handle the double_blocks.
|
||||
block_index = 0
|
||||
while True:
|
||||
converted_double_block_sd = _convert_flux_double_block_sd_from_diffusers_to_bfl_format(sd, block_index)
|
||||
if len(converted_double_block_sd) == 0:
|
||||
break
|
||||
new_sd.update(converted_double_block_sd)
|
||||
block_index += 1
|
||||
|
||||
# Handle the single_blocks.
|
||||
...
|
||||
|
||||
# Transfer controlnet keys as-is.
|
||||
for k in sd:
|
||||
if k.startswith("controlnet_"):
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
# Assert that all keys have been handled.
|
||||
assert len(sd) == 0
|
||||
return new_sd
|
||||
|
||||
Reference in New Issue
Block a user