mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Finish first draft of convert_diffusers_instantx_state_dict_to_bfl_format(...).
This commit is contained in:
@@ -43,87 +43,129 @@ def is_state_dict_instantx_controlnet(sd: Dict[str, Any]) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _fuse_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
def _fuse_weights(*t: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Fuse weights along dimension 0.
|
||||||
|
|
||||||
|
Used to fuse q, k, v attention weights into a single qkv tensor when converting from diffusers to BFL format.
|
||||||
|
"""
|
||||||
# TODO(ryand): Double check dim=0 is correct.
|
# TODO(ryand): Double check dim=0 is correct.
|
||||||
return torch.cat((q, k, v), dim=0)
|
return torch.cat(t, dim=0)
|
||||||
|
|
||||||
|
|
||||||
def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
|
def _convert_flux_double_block_sd_from_diffusers_to_bfl_format(
|
||||||
sd: Dict[str, torch.Tensor], double_block_index: int
|
sd: Dict[str, torch.Tensor], double_block_index: int
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, torch.Tensor]:
|
||||||
"""Convert the state dict for a double block from diffusers format to BFL format."""
|
"""Convert the state dict for a double block from diffusers format to BFL format."""
|
||||||
|
to_prefix = f"double_blocks.{double_block_index}."
|
||||||
# double_blocks.0.img_attn.norm.key_norm.scale
|
from_prefix = f"transformer_blocks.{double_block_index}."
|
||||||
# double_blocks.0.img_attn.norm.query_norm.scale
|
|
||||||
# double_blocks.0.img_attn.proj.bias
|
|
||||||
# double_blocks.0.img_attn.proj.weight
|
|
||||||
# double_blocks.0.img_attn.qkv.bias
|
|
||||||
# double_blocks.0.img_attn.qkv.weight
|
|
||||||
# double_blocks.0.img_mlp.0.bias
|
|
||||||
# double_blocks.0.img_mlp.0.weight
|
|
||||||
# double_blocks.0.img_mlp.2.bias
|
|
||||||
# double_blocks.0.img_mlp.2.weight
|
|
||||||
# double_blocks.0.img_mod.lin.bias
|
|
||||||
# double_blocks.0.img_mod.lin.weight
|
|
||||||
# double_blocks.0.txt_attn.norm.key_norm.scale
|
|
||||||
# double_blocks.0.txt_attn.norm.query_norm.scale
|
|
||||||
# double_blocks.0.txt_attn.proj.bias
|
|
||||||
# double_blocks.0.txt_attn.proj.weight
|
|
||||||
# double_blocks.0.txt_attn.qkv.bias
|
|
||||||
# double_blocks.0.txt_attn.qkv.weight
|
|
||||||
# double_blocks.0.txt_mlp.0.bias
|
|
||||||
# double_blocks.0.txt_mlp.0.weight
|
|
||||||
# double_blocks.0.txt_mlp.2.bias
|
|
||||||
# double_blocks.0.txt_mlp.2.weight
|
|
||||||
# double_blocks.0.txt_mod.lin.bias
|
|
||||||
# double_blocks.0.txt_mod.lin.weight
|
|
||||||
|
|
||||||
# "transformer_blocks.0.attn.add_k_proj.bias",
|
|
||||||
# "transformer_blocks.0.attn.add_k_proj.weight",
|
|
||||||
# "transformer_blocks.0.attn.add_q_proj.bias",
|
|
||||||
# "transformer_blocks.0.attn.add_q_proj.weight",
|
|
||||||
# "transformer_blocks.0.attn.add_v_proj.bias",
|
|
||||||
# "transformer_blocks.0.attn.add_v_proj.weight",
|
|
||||||
# "transformer_blocks.0.attn.norm_added_k.weight",
|
|
||||||
# "transformer_blocks.0.attn.norm_added_q.weight",
|
|
||||||
# "transformer_blocks.0.attn.norm_k.weight",
|
|
||||||
# "transformer_blocks.0.attn.norm_q.weight",
|
|
||||||
# "transformer_blocks.0.attn.to_add_out.bias",
|
|
||||||
# "transformer_blocks.0.attn.to_add_out.weight",
|
|
||||||
# "transformer_blocks.0.attn.to_k.bias",
|
|
||||||
# "transformer_blocks.0.attn.to_k.weight",
|
|
||||||
# "transformer_blocks.0.attn.to_out.0.bias",
|
|
||||||
# "transformer_blocks.0.attn.to_out.0.weight",
|
|
||||||
# "transformer_blocks.0.attn.to_q.bias",
|
|
||||||
# "transformer_blocks.0.attn.to_q.weight",
|
|
||||||
# "transformer_blocks.0.attn.to_v.bias",
|
|
||||||
# "transformer_blocks.0.attn.to_v.weight",
|
|
||||||
# "transformer_blocks.0.ff.net.0.proj.bias",
|
|
||||||
# "transformer_blocks.0.ff.net.0.proj.weight",
|
|
||||||
# "transformer_blocks.0.ff.net.2.bias",
|
|
||||||
# "transformer_blocks.0.ff.net.2.weight",
|
|
||||||
# "transformer_blocks.0.ff_context.net.0.proj.bias",
|
|
||||||
# "transformer_blocks.0.ff_context.net.0.proj.weight",
|
|
||||||
# "transformer_blocks.0.ff_context.net.2.bias",
|
|
||||||
# "transformer_blocks.0.ff_context.net.2.weight",
|
|
||||||
# "transformer_blocks.0.norm1.linear.bias",
|
|
||||||
# "transformer_blocks.0.norm1.linear.weight",
|
|
||||||
# "transformer_blocks.0.norm1_context.linear.bias",
|
|
||||||
# "transformer_blocks.0.norm1_context.linear.weight",
|
|
||||||
|
|
||||||
new_sd: dict[str, torch.Tensor] = {}
|
new_sd: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.bias"] = _fuse_qkv(
|
# Check one key to determine if this block exists.
|
||||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.bias"),
|
if f"{from_prefix}.attn.add_q_proj.bias" not in sd:
|
||||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.bias"),
|
return new_sd
|
||||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.bias"),
|
|
||||||
|
# txt_attn.qkv
|
||||||
|
new_sd[f"{to_prefix}.txt_attn.qkv.bias"] = _fuse_weights(
|
||||||
|
sd.pop(f"{from_prefix}.attn.add_q_proj.bias"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.add_k_proj.bias"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.add_v_proj.bias"),
|
||||||
)
|
)
|
||||||
new_sd[f"double_blocks.{double_block_index}.txt_attn.qkv.weight"] = _fuse_qkv(
|
new_sd[f"{to_prefix}.txt_attn.qkv.weight"] = _fuse_weights(
|
||||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_q_proj.weight"),
|
sd.pop(f"{from_prefix}.attn.add_q_proj.weight"),
|
||||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_k_proj.weight"),
|
sd.pop(f"{from_prefix}.attn.add_k_proj.weight"),
|
||||||
sd.pop(f"transformer_blocks.{double_block_index}.attn.add_v_proj.weight"),
|
sd.pop(f"{from_prefix}.attn.add_v_proj.weight"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# img_attn.qkv
|
||||||
|
new_sd[f"{to_prefix}.img_attn.qkv.bias"] = _fuse_weights(
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_q.bias"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_k.bias"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_v.bias"),
|
||||||
|
)
|
||||||
|
new_sd[f"{to_prefix}.img_attn.qkv.weight"] = _fuse_weights(
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_q.weight"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_k.weight"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_v.weight"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle basic 1-to-1 key conversions.
|
||||||
|
key_map = {
|
||||||
|
# img_attn
|
||||||
|
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
|
||||||
|
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
|
||||||
|
"attn.to_out.0.weight": "img_attn.proj.weight",
|
||||||
|
"attn.to_out.0.bias": "img_attn.proj.bias",
|
||||||
|
# img_mlp
|
||||||
|
"ff.net.0.proj.weight": "img_mlp.0.weight",
|
||||||
|
"ff.net.0.proj.bias": "img_mlp.0.bias",
|
||||||
|
"ff.net.2.weight": "img_mlp.2.weight",
|
||||||
|
"ff.net.2.bias": "img_mlp.2.bias",
|
||||||
|
# img_mod
|
||||||
|
"norm1.linear.weight": "img_mod.lin.weight",
|
||||||
|
"norm1.linear.bias": "img_mod.lin.bias",
|
||||||
|
# txt_attn
|
||||||
|
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
|
||||||
|
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
|
||||||
|
"attn.to_add_out.weight": "txt_attn.proj.weight",
|
||||||
|
"attn.to_add_out.bias": "txt_attn.proj.bias",
|
||||||
|
# txt_mlp
|
||||||
|
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
|
||||||
|
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
|
||||||
|
"ff_context.net.2.weight": "txt_mlp.2.weight",
|
||||||
|
"ff_context.net.2.bias": "txt_mlp.2.bias",
|
||||||
|
# txt_mod
|
||||||
|
"norm1_context.linear.weight": "txt_mod.lin.weight",
|
||||||
|
"norm1_context.linear.bias": "txt_mod.lin.bias",
|
||||||
|
}
|
||||||
|
for from_key, to_key in key_map.items():
|
||||||
|
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
|
||||||
|
|
||||||
|
return new_sd
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_flux_single_block_sd_from_diffusers_to_bfl_format(
|
||||||
|
sd: Dict[str, torch.Tensor], single_block_index: int
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
"""Convert the state dict for a single block from diffusers format to BFL format."""
|
||||||
|
to_prefix = f"single_blocks.{single_block_index}."
|
||||||
|
from_prefix = f"single_transformer_blocks.{single_block_index}."
|
||||||
|
|
||||||
|
new_sd: dict[str, torch.Tensor] = {}
|
||||||
|
|
||||||
|
# Check one key to determine if this block exists.
|
||||||
|
if f"{from_prefix}.attn.to_q.bias" not in sd:
|
||||||
|
return new_sd
|
||||||
|
|
||||||
|
# linear1 (qkv)
|
||||||
|
new_sd[f"{to_prefix}.linear1.bias"] = _fuse_weights(
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_q.bias"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_k.bias"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_v.bias"),
|
||||||
|
sd.pop(f"{from_prefix}.proj_mlp.bias"),
|
||||||
|
)
|
||||||
|
new_sd[f"{to_prefix}.linear1.weight"] = _fuse_weights(
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_q.weight"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_k.weight"),
|
||||||
|
sd.pop(f"{from_prefix}.attn.to_v.weight"),
|
||||||
|
sd.pop(f"{from_prefix}.proj_mlp.weight"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle basic 1-to-1 key conversions.
|
||||||
|
key_map = {
|
||||||
|
# linear2
|
||||||
|
"proj_out.weight": "linear2.weight",
|
||||||
|
"proj_out.bias": "linear2.bias",
|
||||||
|
# modulation
|
||||||
|
"norm.linear.weight": "modulation.lin.weight",
|
||||||
|
"norm.linear.bias": "modulation.lin.bias",
|
||||||
|
# norm
|
||||||
|
"attn.norm_k.weight": "norm.key_norm.scale",
|
||||||
|
"attn.norm_q.weight": "norm.query_norm.scale",
|
||||||
|
}
|
||||||
|
for from_key, to_key in key_map.items():
|
||||||
|
new_sd[f"{to_prefix}.{to_key}"] = sd.pop(f"{from_prefix}.{from_key}")
|
||||||
|
|
||||||
return new_sd
|
return new_sd
|
||||||
|
|
||||||
|
|
||||||
@@ -181,7 +223,13 @@ def convert_diffusers_instantx_state_dict_to_bfl_format(sd: Dict[str, torch.Tens
|
|||||||
block_index += 1
|
block_index += 1
|
||||||
|
|
||||||
# Handle the single_blocks.
|
# Handle the single_blocks.
|
||||||
...
|
block_index = 0
|
||||||
|
while True:
|
||||||
|
converted_singe_block_sd = _convert_flux_single_block_sd_from_diffusers_to_bfl_format(sd, block_index)
|
||||||
|
if len(converted_singe_block_sd) == 0:
|
||||||
|
break
|
||||||
|
new_sd.update(converted_singe_block_sd)
|
||||||
|
block_index += 1
|
||||||
|
|
||||||
# Transfer controlnet keys as-is.
|
# Transfer controlnet keys as-is.
|
||||||
for k in sd:
|
for k in sd:
|
||||||
|
|||||||
Reference in New Issue
Block a user