mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-19 09:54:24 -05:00
Finish first draft of convert_diffusers_instantx_state_dict_to_bfl_format(...).
This commit is contained in:
@@ -43,87 +43,129 @@ def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _fuse_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
||||
def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
|
||||
"""Fuse weights along dimension 0.
|
||||
|
||||
Used to fuse q, k, v attention weights into a single qkv tensor when converting from diffusers to BFL format.
|
||||
"""
|
||||
# TODO(ryand): Double check dim=0 is correct.
|
||||
return torch.cat((q, k, v), dim=0)
|
||||
return torch.cat(t, dim=0)
|
||||
|
||||
|
||||
def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
|
||||
sd: Dict[str, torch.Tensor], double_block_index: int
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the state dict for a double block from diffusers format to BFL format."""
|
||||
|
||||
# double_blocks.0.img_attn.norm.key_norm.scale
|
||||
# double_blocks.0.img_attn.norm.query_norm.scale
|
||||
# double_blocks.0.img_attn.proj.bias
|
||||
# double_blocks.0.img_attn.proj.weight
|
||||
# double_blocks.0.img_attn.qkv.bias
|
||||
# double_blocks.0.img_attn.qkv.weight
|
||||
# double_blocks.0.img_mlp.0.bias
|
||||
# double_blocks.0.img_mlp.0.weight
|
||||
# double_blocks.0.img_mlp.2.bias
|
||||
# double_blocks.0.img_mlp.2.weight
|
||||
# double_blocks.0.img_mod.lin.bias
|
||||
# double_blocks.0.img_mod.lin.weight
|
||||
# double_blocks.0.txt_attn.norm.key_norm.scale
|
||||
# double_blocks.0.txt_attn.norm.query_norm.scale
|
||||
# double_blocks.0.txt_attn.proj.bias
|
||||
# double_blocks.0.txt_attn.proj.weight
|
||||
# double_blocks.0.txt_attn.qkv.bias
|
||||
# double_blocks.0.txt_attn.qkv.weight
|
||||
# double_blocks.0.txt_mlp.0.bias
|
||||
# double_blocks.0.txt_mlp.0.weight
|
||||
# double_blocks.0.txt_mlp.2.bias
|
||||
# double_blocks.0.txt_mlp.2.weight
|
||||
# double_blocks.0.txt_mod.lin.bias
|
||||
# double_blocks.0.txt_mod.lin.weight
|
||||
|
||||
# "transformer_blocks.0.attn.add_k_proj.bias",
|
||||
# "transformer_blocks.0.attn.add_k_proj.weight",
|
||||
# "transformer_blocks.0.attn.add_q_proj.bias",
|
||||
# "transformer_blocks.0.attn.add_q_proj.weight",
|
||||
# "transformer_blocks.0.attn.add_v_proj.bias",
|
||||
# "transformer_blocks.0.attn.add_v_proj.weight",
|
||||
# "transformer_blocks.0.attn.norm_added_k.weight",
|
||||
# "transformer_blocks.0.attn.norm_added_q.weight",
|
||||
# "transformer_blocks.0.attn.norm_k.weight",
|
||||
# "transformer_blocks.0.attn.norm_q.weight",
|
||||
# "transformer_blocks.0.attn.to_add_out.bias",
|
||||
# "transformer_blocks.0.attn.to_add_out.weight",
|
||||
# "transformer_blocks.0.attn.to_k.bias",
|
||||
# "transformer_blocks.0.attn.to_k.weight",
|
||||
# "transformer_blocks.0.attn.to_out.0.bias",
|
||||
# "transformer_blocks.0.attn.to_out.0.weight",
|
||||
# "transformer_blocks.0.attn.to_q.bias",
|
||||
# "transformer_blocks.0.attn.to_q.weight",
|
||||
# "transformer_blocks.0.attn.to_v.bias",
|
||||
# "transformer_blocks.0.attn.to_v.weight",
|
||||
# "transformer_blocks.0.ff.net.0.proj.bias",
|
||||
# "transformer_blocks.0.ff.net.0.proj.weight",
|
||||
# "transformer_blocks.0.ff.net.2.bias",
|
||||
# "transformer_blocks.0.ff.net.2.weight",
|
||||
# "transformer_blocks.0.ff_context.net.0.proj.bias",
|
||||
# "transformer_blocks.0.ff_context.net.0.proj.weight",
|
||||
# "transformer_blocks.0.ff_context.net.2.bias",
|
||||
# "transformer_blocks.0.ff_context.net.2.weight",
|
||||
# "transformer_blocks.0.norm1.linear.bias",
|
||||
# "transformer_blocks.0.norm1.linear.weight",
|
||||
# "transformer_blocks.0.norm1_context.linear.bias",
|
||||
# "transformer_blocks.0.norm1_context.linear.weight",
|
||||
to_prefix = f"double_blocks.{double_block_index}."
|
||||
from_prefix = f"transformer_blocks.{double_block_index}."
|
||||
|
||||
new_sd: dict[str, torch.Tensor] = {}
|
||||
|
||||
new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.bias"] = _fuse_qkv(
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.bias"),
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.bias"),
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.bias"),
|
||||
# Check one key to determine if this block exists.
|
||||
if f"{from_prefix}.attn.add_q_proj.bias" not in sd:
|
||||
return new_sd
|
||||
|
||||
# txt_attn.qkv
|
||||
new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights(
|
||||
sd.pop(f"{from_prefix}.attn.add_q_proj.bias"),
|
||||
sd.pop(f"{from_prefix}.attn.add_k_proj.bias"),
|
||||
sd.pop(f"{from_prefix}.attn.add_v_proj.bias"),
|
||||
)
|
||||
new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.weight"] = _fuse_qkv(
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.weight"),
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.weight"),
|
||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.weight"),
|
||||
new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights(
|
||||
sd.pop(f"{from_prefix}.attn.add_q_proj.weight"),
|
||||
sd.pop(f"{from_prefix}.attn.add_k_proj.weight"),
|
||||
sd.pop(f"{from_prefix}.attn.add_v_proj.weight"),
|
||||
)
|
||||
|
||||
# img_attn.qkv
|
||||
new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights(
|
||||
sd.pop(f"{from_prefix}.attn.to_q.bias"),
|
||||
sd.pop(f"{from_prefix}.attn.to_k.bias"),
|
||||
sd.pop(f"{from_prefix}.attn.to_v.bias"),
|
||||
)
|
||||
new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights(
|
||||
sd.pop(f"{from_prefix}.attn.to_q.weight"),
|
||||
sd.pop(f"{from_prefix}.attn.to_k.weight"),
|
||||
sd.pop(f"{from_prefix}.attn.to_v.weight"),
|
||||
)
|
||||
|
||||
# Handle basic 1-to-1 key conversions.
|
||||
key_map = {
|
||||
# img_attn
|
||||
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||
"attn.to_out.0.weight": "img_attn.proj.weight",
|
||||
"attn.to_out.0.bias": "img_attn.proj.bias",
|
||||
# img_mlp
|
||||
"ff.net.0.proj.weight": "img_mlp.0.weight",
|
||||
"ff.net.0.proj.bias": "img_mlp.0.bias",
|
||||
"ff.net.2.weight": "img_mlp.2.weight",
|
||||
"ff.net.2.bias": "img_mlp.2.bias",
|
||||
# img_mod
|
||||
"norm1.linear.weight": "img_mod.lin.weight",
|
||||
"norm1.linear.bias": "img_mod.lin.bias",
|
||||
# txt_attn
|
||||
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||
"attn.to_add_out.weight": "txt_attn.proj.weight",
|
||||
"attn.to_add_out.bias": "txt_attn.proj.bias",
|
||||
# txt_mlp
|
||||
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
|
||||
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
||||
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
||||
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
||||
# txt_mod
|
||||
"norm1_context.linear.weight": "txt_mod.lin.weight",
|
||||
"norm1_context.linear.bias": "txt_mod.lin.bias",
|
||||
}
|
||||
for from_key, to_key in key_map.items():
|
||||
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
|
||||
sd: Dict[str, torch.Tensor], single_block_index: int
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Convert the state dict for a single block from diffusers format to BFL format."""
|
||||
to_prefix = f"single_blocks.{single_block_index}."
|
||||
from_prefix = f"single_transformer_blocks.{single_block_index}."
|
||||
|
||||
new_sd: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Check one key to determine if this block exists.
|
||||
if f"{from_prefix}.attn.to_q.bias" not in sd:
|
||||
return new_sd
|
||||
|
||||
# linear1 (qkv)
|
||||
new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights(
|
||||
sd.pop(f"{from_prefix}.attn.to_q.bias"),
|
||||
sd.pop(f"{from_prefix}.attn.to_k.bias"),
|
||||
sd.pop(f"{from_prefix}.attn.to_v.bias"),
|
||||
sd.pop(f"{from_prefix}.proj_mlp.bias"),
|
||||
)
|
||||
new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights(
|
||||
sd.pop(f"{from_prefix}.attn.to_q.weight"),
|
||||
sd.pop(f"{from_prefix}.attn.to_k.weight"),
|
||||
sd.pop(f"{from_prefix}.attn.to_v.weight"),
|
||||
sd.pop(f"{from_prefix}.proj_mlp.weight"),
|
||||
)
|
||||
|
||||
# Handle basic 1-to-1 key conversions.
|
||||
key_map = {
|
||||
# linear2
|
||||
"proj_out.weight": "linear2.weight",
|
||||
"proj_out.bias": "linear2.bias",
|
||||
# modulation
|
||||
"norm.linear.weight": "modulation.lin.weight",
|
||||
"norm.linear.bias": "modulation.lin.bias",
|
||||
# norm
|
||||
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||
}
|
||||
for from_key, to_key in key_map.items():
|
||||
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
|
||||
|
||||
return new_sd
|
||||
|
||||
|
||||
@@ -181,7 +223,13 @@ def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tens
|
||||
block_index += 1
|
||||
|
||||
# Handle the single_blocks.
|
||||
...
|
||||
block_index = 0
|
||||
while True:
|
||||
converted_singe_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index)
|
||||
if len(converted_singe_block_sd) == 0:
|
||||
break
|
||||
new_sd.update(converted_singe_block_sd)
|
||||
block_index += 1
|
||||
|
||||
# Transfer controlnet keys as-is.
|
||||
for k in sd:
|
||||
|
||||
Reference in New Issue
Block a user