Finish first draft of convert_diffusers_instantx_state_dict_to_bfl_format(...).

Ryan Dick
2024-10-07 14:28:01 +00:00
committed by Kent Keirsey
parent 76f4766324
commit a9e7ecad49


@@ -43,87 +43,129 @@ def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
     return False


-def _fuse_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
     """Fuse weights along dimension 0.

     Used to fuse q, k, v attention weights into a single qkv tensor when converting from diffusers to BFL format.
     """
-    # TODO(ryand): Double check dim=0 is correct.
-    return torch.cat((q, k, v), dim=0)
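+    # E.g. three [H, H] q/k/v weight tensors fuse into one [3 * H, H] qkv tensor.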
+    return torch.cat(t, dim=0)


 def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
     sd: Dict[str, torch.Tensor], double_block_index: int
 ) -> Dict[str, torch.Tensor]:
     """Convert the state dict for a double block from diffusers format to BFL format."""
     # double_blocks.0.img_attn.norm.key_norm.scale
     # double_blocks.0.img_attn.norm.query_norm.scale
     # double_blocks.0.img_attn.proj.bias
     # double_blocks.0.img_attn.proj.weight
     # double_blocks.0.img_attn.qkv.bias
     # double_blocks.0.img_attn.qkv.weight
     # double_blocks.0.img_mlp.0.bias
     # double_blocks.0.img_mlp.0.weight
     # double_blocks.0.img_mlp.2.bias
     # double_blocks.0.img_mlp.2.weight
     # double_blocks.0.img_mod.lin.bias
     # double_blocks.0.img_mod.lin.weight
     # double_blocks.0.txt_attn.norm.key_norm.scale
     # double_blocks.0.txt_attn.norm.query_norm.scale
     # double_blocks.0.txt_attn.proj.bias
     # double_blocks.0.txt_attn.proj.weight
     # double_blocks.0.txt_attn.qkv.bias
     # double_blocks.0.txt_attn.qkv.weight
     # double_blocks.0.txt_mlp.0.bias
     # double_blocks.0.txt_mlp.0.weight
     # double_blocks.0.txt_mlp.2.bias
     # double_blocks.0.txt_mlp.2.weight
     # double_blocks.0.txt_mod.lin.bias
     # double_blocks.0.txt_mod.lin.weight
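     # For reference, the corresponding diffusers (source) keys: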
# "transformer_blocks.0.attn.add_k_proj.bias",
# "transformer_blocks.0.attn.add_k_proj.weight",
# "transformer_blocks.0.attn.add_q_proj.bias",
# "transformer_blocks.0.attn.add_q_proj.weight",
# "transformer_blocks.0.attn.add_v_proj.bias",
# "transformer_blocks.0.attn.add_v_proj.weight",
# "transformer_blocks.0.attn.norm_added_k.weight",
# "transformer_blocks.0.attn.norm_added_q.weight",
# "transformer_blocks.0.attn.norm_k.weight",
# "transformer_blocks.0.attn.norm_q.weight",
# "transformer_blocks.0.attn.to_add_out.bias",
# "transformer_blocks.0.attn.to_add_out.weight",
# "transformer_blocks.0.attn.to_k.bias",
# "transformer_blocks.0.attn.to_k.weight",
# "transformer_blocks.0.attn.to_out.0.bias",
# "transformer_blocks.0.attn.to_out.0.weight",
# "transformer_blocks.0.attn.to_q.bias",
# "transformer_blocks.0.attn.to_q.weight",
# "transformer_blocks.0.attn.to_v.bias",
# "transformer_blocks.0.attn.to_v.weight",
# "transformer_blocks.0.ff.net.0.proj.bias",
# "transformer_blocks.0.ff.net.0.proj.weight",
# "transformer_blocks.0.ff.net.2.bias",
# "transformer_blocks.0.ff.net.2.weight",
# "transformer_blocks.0.ff_context.net.0.proj.bias",
# "transformer_blocks.0.ff_context.net.0.proj.weight",
# "transformer_blocks.0.ff_context.net.2.bias",
# "transformer_blocks.0.ff_context.net.2.weight",
# "transformer_blocks.0.norm1.linear.bias",
# "transformer_blocks.0.norm1.linear.weight",
# "transformer_blocks.0.norm1_context.linear.bias",
# "transformer_blocks.0.norm1_context.linear.weight",
to_prefix = f"double_blocks.{double_block_index}."
from_prefix = f"transformer_blocks.{double_block_index}."
new_sd: dict[str, torch.Tensor] = {}
new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.bias"] = _fuse_qkv(
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.bias"),
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.bias"),
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.bias"),
# Check one key to determine if this block exists.
if f"{from_prefix}.attn.add_q_proj.bias" not in sd:
return new_sd
# txt_attn.qkv
new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.add_q_proj.bias"),
sd.pop(f"{from_prefix}.attn.add_k_proj.bias"),
sd.pop(f"{from_prefix}.attn.add_v_proj.bias"),
)
new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.weight"] = _fuse_qkv(
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.weight"),
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.weight"),
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.weight"),
new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights(
sd.pop(f"{from_prefix}.attn.add_q_proj.weight"),
sd.pop(f"{from_prefix}.attn.add_k_proj.weight"),
sd.pop(f"{from_prefix}.attn.add_v_proj.weight"),
)
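+    # In diffusers, to_q/to_k/to_v project the image (latent) stream, while
+    # add_q_proj/add_k_proj/add_v_proj project the text (context) stream.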
+    # img_attn.qkv
+    new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights(
+        sd.pop(f"{from_prefix}.attn.to_q.bias"),
+        sd.pop(f"{from_prefix}.attn.to_k.bias"),
+        sd.pop(f"{from_prefix}.attn.to_v.bias"),
+    )
+    new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights(
+        sd.pop(f"{from_prefix}.attn.to_q.weight"),
+        sd.pop(f"{from_prefix}.attn.to_k.weight"),
+        sd.pop(f"{from_prefix}.attn.to_v.weight"),
+    )
+    # Handle basic 1-to-1 key conversions.
+    key_map = {
+        # img_attn
+        "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
+        "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
+        "attn.to_out.0.weight": "img_attn.proj.weight",
+        "attn.to_out.0.bias": "img_attn.proj.bias",
+        # img_mlp
+        "ff.net.0.proj.weight": "img_mlp.0.weight",
+        "ff.net.0.proj.bias": "img_mlp.0.bias",
+        "ff.net.2.weight": "img_mlp.2.weight",
+        "ff.net.2.bias": "img_mlp.2.bias",
+        # img_mod
+        "norm1.linear.weight": "img_mod.lin.weight",
+        "norm1.linear.bias": "img_mod.lin.bias",
+        # txt_attn
+        "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
+        "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+        "attn.to_add_out.weight": "txt_attn.proj.weight",
+        "attn.to_add_out.bias": "txt_attn.proj.bias",
+        # txt_mlp
+        "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
+        "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
+        "ff_context.net.2.weight": "txt_mlp.2.weight",
+        "ff_context.net.2.bias": "txt_mlp.2.bias",
+        # txt_mod
+        "norm1_context.linear.weight": "txt_mod.lin.weight",
+        "norm1_context.linear.bias": "txt_mod.lin.bias",
+    }
+    for from_key, to_key in key_map.items():
+        new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
     return new_sd
+
+
+def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
+    sd: Dict[str, torch.Tensor], single_block_index: int
+) -> Dict[str, torch.Tensor]:
+    """Convert the state dict for a single block from diffusers format to BFL format."""
+    to_prefix = f"single_blocks.{single_block_index}"
+    from_prefix = f"single_transformer_blocks.{single_block_index}"
+    new_sd: dict[str, torch.Tensor] = {}
+    # Check one key to determine if this block exists.
+    if f"{from_prefix}.attn.to_q.bias" not in sd:
+        return new_sd
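+    # linear1 fuses the attention q/k/v projections and the MLP input
+    # projection into a single matrix, matching BFL's single-block layout.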
+    # linear1 (qkv)
+    new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights(
+        sd.pop(f"{from_prefix}.attn.to_q.bias"),
+        sd.pop(f"{from_prefix}.attn.to_k.bias"),
+        sd.pop(f"{from_prefix}.attn.to_v.bias"),
+        sd.pop(f"{from_prefix}.proj_mlp.bias"),
+    )
+    new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights(
+        sd.pop(f"{from_prefix}.attn.to_q.weight"),
+        sd.pop(f"{from_prefix}.attn.to_k.weight"),
+        sd.pop(f"{from_prefix}.attn.to_v.weight"),
+        sd.pop(f"{from_prefix}.proj_mlp.weight"),
+    )
+    # Handle basic 1-to-1 key conversions.
+    key_map = {
+        # linear2
+        "proj_out.weight": "linear2.weight",
+        "proj_out.bias": "linear2.bias",
+        # modulation
+        "norm.linear.weight": "modulation.lin.weight",
+        "norm.linear.bias": "modulation.lin.bias",
+        # norm
+        "attn.norm_k.weight": "norm.key_norm.scale",
+        "attn.norm_q.weight": "norm.query_norm.scale",
+    }
+    for from_key, to_key in key_map.items():
+        new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
+    return new_sd
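As a quick illustration of the fusion these helpers perform, here is a minimal, self-contained sketch (the helper is re-defined locally, and the sizes are toy values rather than the real FLUX dimensions):

import torch

def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
    return torch.cat(t, dim=0)

H, MLP = 8, 32  # toy sizes, not the real FLUX dimensions
q, k, v = (torch.zeros(H, H) for _ in range(3))
proj_mlp = torch.zeros(MLP, H)

# Double block: q/k/v fuse into a single qkv weight of shape [3 * H, H].
assert _fuse_weights(q, k, v).shape == (3 * H, H)

# Single block: q/k/v plus the MLP input projection fuse into linear1,
# of shape [3 * H + MLP, H].
assert _fuse_weights(q, k, v, proj_mlp).shape == (3 * H + MLP, H)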
@@ -181,7 +223,13 @@ def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tens
         block_index += 1

     # Handle the single_blocks.
-    ...
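+    # Convert blocks in index order until the converter finds no tensors for
+    # the next index and returns an empty dict.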
+    block_index = 0
+    while True:
+        converted_single_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index)
+        if len(converted_single_block_sd) == 0:
+            break
+        new_sd.update(converted_single_block_sd)
+        block_index += 1
     # Transfer controlnet keys as-is.
     for k in sd: