Get convert_flux_kohya_state_dict_to_invoke_format(...) working, with unit tests.

This commit is contained in:
Ryan Dick
2024-09-04 13:42:12 +00:00
committed by Kent Keirsey
parent c41bd59812
commit ade75b4748
4 changed files with 51 additions and 57 deletions

View File

@@ -1,67 +1,22 @@
import re
from typing import Dict
import torch
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Convert a state dict from the Kohya FLUX LoRA format to the LoRA weight format used internally by InvokeAI.

    Keys that do not match the Kohya transformer-block pattern are passed through unchanged.

    Example key conversions:

    "lora_unet_double_blocks_0_img_attn_proj.alpha" -> "double_blocks.0.img_attn.proj.alpha"
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight" -> "double_blocks.0.img_attn.proj.lora_down.weight"
    "lora_unet_double_blocks_0_img_attn_qkv.alpha" -> "double_blocks.0.img_attn.qkv.alpha"
    "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight" -> "double_blocks.0.img_attn.qkv.lora_up.weight"
    "lora_unet_single_blocks_0_linear1.alpha" -> "single_blocks.0.linear1.alpha"
    """
    # Groups: 1 = block family ("double_blocks"/"single_blocks"), 2 = block index,
    # 3 = sub-module name, 4 = remainder of the key (may be empty or start with ".").
    pattern = re.compile(
        r"lora_unet_(\w+_blocks)_(\d+)_"
        r"(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)"
        r"_?(.*)"
    )

    def _replace(match: "re.Match[str]") -> str:
        prefix = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
        rest = match.group(4)
        if not rest:
            return prefix
        # NOTE(review): single-block keys such as "lora_unet_single_blocks_0_linear1.alpha"
        # have no underscore after the sub-module name, so the remainder already starts
        # with "."; a plain r"\1.\2.\3.\4" template would emit "linear1..alpha" here.
        if rest.startswith("."):
            return prefix + rest
        return f"{prefix}.{rest}"

    return {pattern.sub(_replace, k): v for k, v in state_dict.items()}