def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Converts a state dict from the Kohya FLUX LoRA format to the LoRA weight format used internally by InvokeAI.

    Example key conversions:

    "lora_unet_double_blocks_0_img_attn_proj.alpha" -> "double_blocks.0.img_attn.proj.alpha"
    "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight" -> "double_blocks.0.img_attn.proj.lora_down.weight"
    "lora_unet_double_blocks_0_img_attn_qkv.alpha" -> "double_blocks.0.img_attn.qkv.alpha"
    "lora_unet_single_blocks_0_linear1.lora_up.weight" -> "single_blocks.0.linear1.lora_up.weight"

    Args:
        state_dict: A Kohya-format FLUX LoRA state dict. Values are passed through untouched; only keys are renamed.

    Returns:
        A new dict with renamed keys. Keys that do not match the Kohya pattern are left unchanged.
    """
    # Group 1: block group ("double_blocks" or "single_blocks").
    # Group 2: block index.
    # Group 3: sub-module within the block.
    # Group 4: remainder of the key, e.g. "proj.lora_up.weight", or ".alpha" when the sub-module name (linear1,
    #          linear2) is the last path component and the optional "_" separator matched empty.
    pattern = re.compile(
        r"lora_unet_(\w+_blocks)_(\d+)_"
        r"(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)"
        r"_?(.*)"
    )

    def _replace(match: re.Match) -> str:
        new_key = f"{match.group(1)}.{match.group(2)}.{match.group(3)}"
        rest = match.group(4)
        if rest:
            # NOTE: a plain r"\1.\2.\3.\4" replacement is wrong here: for linear1/linear2 keys the remainder already
            # starts with "." (e.g. ".alpha"), which produced malformed keys like "single_blocks.0.linear1..alpha".
            # Only insert the separator when the remainder does not carry its own.
            new_key += rest if rest.startswith(".") else f".{rest}"
        return new_key

    return {pattern.sub(_replace, k): v for k, v in state_dict.items()}
def test_convert_flux_kohya_state_dict_to_invoke_format():
    """Verify that converting a Kohya FLUX LoRA state dict produces keys that name real FLUX transformer modules."""
    # Construct a dummy state dict from state_dict_keys. The values are never inspected — only the keys matter.
    state_dict: dict[str, torch.Tensor] = {k: torch.empty(1) for k in state_dict_keys}

    converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict)

    # Strip the LoRA-specific suffixes (.lora_up.weight, .lora_down.weight, .alpha) so that each remaining prefix
    # should name a Linear layer in the model.
    converted_key_prefixes: list[str] = []
    for k in converted_state_dict.keys():
        for suffix in (".lora_up.weight", ".lora_down.weight", ".alpha"):
            k = k.replace(suffix, "")
        converted_key_prefixes.append(k)

    # Initialize a FLUX model on the meta device: no weights are allocated, we only need its state dict keys.
    with torch.device("meta"):
        model = Flux(params["flux-dev"])
    model_keys = set(model.state_dict().keys())

    for prefix in converted_key_prefixes:
        # Every Kohya FLUX LoRA target is a Linear layer, so an exact f"{prefix}.weight" key must exist in the model.
        # An exact-membership check (rather than a startswith() scan) also catches malformed prefixes, e.g. a
        # trailing "." left over from a bad key conversion, which a prefix scan would silently accept.
        assert f"{prefix}.weight" in model_keys, f"Could not find a model weight for converted key prefix: {prefix}"