Get convert_flux_kohya_state_dict_to_invoke_format(...) working, with unit tests.

This commit is contained in:
Ryan Dick
2024-09-04 13:42:12 +00:00
committed by Kent Keirsey
parent c41bd59812
commit ade75b4748
4 changed files with 51 additions and 57 deletions

View File

@@ -1,67 +1,22 @@
import re
from typing import Dict
import torch
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Converts a state dict from the Kohya FLUX LoRA format to the LoRA weight format used internally by InvokeAI.

    Keys that do not match the expected Kohya pattern are passed through unchanged.

    Example key conversions:

    "lora_unet_double_blocks_0_img_attn_proj.alpha" -> "double_blocks.0.img_attn.proj.alpha"
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight" -> "double_blocks.0.img_attn.proj.lora_down.weight"
    "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight" -> "double_blocks.0.img_attn.qkv.lora_up.weight"
    "lora_unet_single_blocks_0_linear1.alpha" -> "single_blocks.0.linear1.alpha"
    """
    # Groups: 1 = block type ("double_blocks" / "single_blocks"), 2 = block index,
    # 3 = module name, 4 = remainder of the key (may be empty or start with ".").
    pattern = r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)"

    def _replace(match: re.Match) -> str:
        # Join the captured pieces with dots. The trailing group may already carry a
        # leading "." (e.g. ".alpha" after "linear1", where the optional "_" matched
        # nothing), so only insert a separator when one is missing. A plain
        # r"\1.\2.\3.\4" replacement would emit a double dot in that case
        # ("single_blocks.0.linear1..alpha").
        converted = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
        suffix = match.group(4)
        if suffix:
            converted += suffix if suffix.startswith(".") else "." + suffix
        return converted

    return {re.sub(pattern, _replace, k): v for k, v in state_dict.items()}

View File

@@ -0,0 +1,39 @@
import torch
from invokeai.backend.flux.model import Flux
from invokeai.backend.flux.util import params
from invokeai.backend.lora.conversions.flux_lora_conversion_utils import convert_flux_kohya_state_dict_to_invoke_format
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys
def test_convert_flux_kohya_state_dict_to_invoke_format():
    """Every converted key prefix must correspond to a real FLUX transformer parameter."""
    # Build a dummy state dict from the fixture keys; tensor contents are irrelevant here.
    state_dict: dict[str, torch.Tensor] = {key: torch.empty(1) for key in state_dict_keys}

    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)

    # Strip the LoRA-specific suffixes (.lora_up.weight, .lora_down.weight, .alpha)
    # to obtain the model-parameter prefixes.
    converted_key_prefixes: list[str] = []
    for key in converted_state_dict.keys():
        prefix = key
        for suffix in (".lora_up.weight", ".lora_down.weight", ".alpha"):
            prefix = prefix.replace(suffix, "")
        converted_key_prefixes.append(prefix)

    # Instantiate the FLUX model on the meta device so no real weights are allocated.
    with torch.device("meta"):
        model = Flux(params["flux-dev"])
    model_keys = set(model.state_dict().keys())

    # Each converted prefix must be the start of at least one model parameter key.
    for converted_key_prefix in converted_key_prefixes:
        found_match = any(model_key.startswith(converted_key_prefix) for model_key in model_keys)
        if not found_match:
            raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")