mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 18:38:25 -05:00
Get convert_flux_kohya_state_dict_to_invoke_format(...) working, with unit tests.
This commit is contained in:
@@ -1,67 +1,22 @@
|
||||
import re
|
||||
from typing import Dict
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Convert a state dict from the Kohya FLUX LoRA format to the LoRA weight format used internally by InvokeAI.

    Keys that do not match the expected Kohya FLUX pattern are passed through unchanged (``re.sub`` returns the
    string unmodified when the pattern does not match).

    Example key conversions:

    "lora_unet_double_blocks_0_img_attn_proj.alpha" -> "double_blocks.0.img_attn.proj.alpha"
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight" -> "double_blocks.0.img_attn.proj.lora_down.weight"
    "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight" -> "double_blocks.0.img_attn.qkv.lora_up.weight"
    "lora_unet_single_blocks_0_linear1.alpha" -> "single_blocks.0.linear1.alpha"

    Args:
        state_dict: A Kohya-format FLUX LoRA state dict. Values are passed through untouched; only keys are renamed.

    Returns:
        A new dict with renamed keys mapping to the same tensor objects.
    """
    # Kohya keys use underscores as separators everywhere, while the Invoke format separates the block name, block
    # index, and module path with periods. The explicit module-name alternation is required because the module names
    # themselves contain underscores (e.g. "img_attn") that must NOT be turned into periods.
    #
    # "[_.]" consumes the single separator that follows the module name: "_" when a sub-module follows (e.g.
    # "img_attn_proj.alpha") and "." when the LoRA suffix follows directly (e.g. "linear1.alpha"). Using an optional
    # "_?" here instead would leave a stray "." in the output for the latter case, producing malformed keys like
    # "single_blocks.0.linear1..alpha".
    pattern = (
        r"lora_unet_(\w+_blocks)_(\d+)_"
        r"(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)"
        r"[_.](.*)"
    )
    replacement = r"\1.\2.\3.\4"
    return {re.sub(pattern, replacement, k): v for k, v in state_dict.items()}
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.flux.model import Flux
|
||||
from invokeai.backend.flux.util import params
|
||||
from invokeai.backend.lora.conversions.flux_lora_conversion_utils import convert_flux_kohya_state_dict_to_invoke_format
|
||||
from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import state_dict_keys
|
||||
|
||||
|
||||
def test_convert_flux_kohya_state_dict_to_invoke_format():
    """Verify that every converted key prefix corresponds to a parameter of the real FLUX transformer."""
    # Build a dummy state dict from the recorded Kohya-format key list; tensor contents are irrelevant here.
    state_dict: dict[str, torch.Tensor] = {key: torch.empty(1) for key in state_dict_keys}

    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)

    # Strip the LoRA-specific suffixes (.lora_up.weight, .lora_down.weight, .alpha) so that only the module-path
    # prefixes remain.
    converted_key_prefixes: list[str] = []
    for key in converted_state_dict.keys():
        prefix = key
        for suffix in (".lora_up.weight", ".lora_down.weight", ".alpha"):
            prefix = prefix.replace(suffix, "")
        converted_key_prefixes.append(prefix)

    # Instantiate the FLUX transformer on the meta device so that no real weight memory is allocated.
    with torch.device("meta"):
        model = Flux(params["flux-dev"])
    model_keys = set(model.state_dict().keys())

    # Every converted prefix must match the start of at least one parameter name in the actual model.
    for converted_key_prefix in converted_key_prefixes:
        if not any(model_key.startswith(converted_key_prefix) for model_key in model_keys):
            raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}")
|
||||
Reference in New Issue
Block a user