From d332d81866b90054fc1e2002cc877b7cc60fb6ee Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 26 Sep 2024 19:41:56 +0000 Subject: [PATCH 1/7] Add ability to load FLUX kohya LoRA models that include patches for both the transformer and T5 models. --- .../flux_kohya_lora_conversion_utils.py | 66 +- .../flux_lora_kohya_with_te1_format.py | 1133 +++++++++++++++++ .../test_flux_kohya_lora_conversion_utils.py | 32 +- 3 files changed, 1205 insertions(+), 26 deletions(-) create mode 100644 tests/backend/lora/conversions/lora_state_dicts/flux_lora_kohya_with_te1_format.py diff --git a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py index 3e1ccf6493..83e61384b9 100644 --- a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py +++ b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py @@ -1,3 +1,4 @@ +import itertools import re from typing import Any, Dict, TypeVar @@ -7,14 +8,20 @@ from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer from invokeai.backend.lora.layers.utils import any_lora_layer_from_state_dict from invokeai.backend.lora.lora_model_raw import LoRAModelRaw -# A regex pattern that matches all of the keys in the Kohya FLUX LoRA format. +# A regex pattern that matches all of the transformer keys in the Kohya FLUX LoRA format. # Example keys: # lora_unet_double_blocks_0_img_attn_proj.alpha # lora_unet_double_blocks_0_img_attn_proj.lora_down.weight # lora_unet_double_blocks_0_img_attn_proj.lora_up.weight -FLUX_KOHYA_KEY_REGEX = ( +FLUX_KOHYA_TRANSFORMER_KEY_REGEX = ( r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)" ) +# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format. +# Example keys: +# lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha +# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight +# lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight +FLUX_KOHYA_T5_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*" def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool: @@ -23,7 +30,9 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) """ - return all(re.match(FLUX_KOHYA_KEY_REGEX, k) for k in state_dict.keys()) + return all( + re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys() + ) def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -> LoRAModelRaw: @@ -35,12 +44,24 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) - grouped_state_dict[layer_name] = {} grouped_state_dict[layer_name][param_name] = value - # Convert the state dict to the InvokeAI format. - grouped_state_dict = convert_flux_kohya_state_dict_to_invoke_format(grouped_state_dict) + # Split the grouped state dict into transformer and T5 state dicts. 
+    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    for layer_name, layer_state_dict in grouped_state_dict.items():
+        if layer_name.startswith("lora_unet"):
+            transformer_grouped_sd[layer_name] = layer_state_dict
+        elif layer_name.startswith("lora_te1"):
+            t5_grouped_sd[layer_name] = layer_state_dict
+        else:
+            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
+
+    # Convert the state dicts to the InvokeAI format.
+    transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
+    t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)

     # Create LoRA layers.
     layers: dict[str, AnyLoRALayer] = {}
-    for layer_key, layer_state_dict in grouped_state_dict.items():
+    for layer_key, layer_state_dict in itertools.chain(transformer_grouped_sd.items(), t5_grouped_sd.items()):
         layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict)

     # Create and return the LoRAModelRaw.
@@ -50,16 +71,33 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
 T = TypeVar("T")


-def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
-    """Converts a state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
+def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a T5 LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI.
+
+    Example key conversions:
+
+    "lora_te1_text_model_encoder_layers_0_mlp_fc1" -> "text_model.encoder.layers.0.mlp.fc1",
+    "lora_te1_text_model_encoder_layers_0_self_attn_k_proj" -> "text_model.encoder.layers.0.self_attn.k_proj"
+    """
+    converted_sd: dict[str, T] = {}
+    for k, v in state_dict.items():
+        match = re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
+        if match:
+            new_key = f"text_model.encoder.layers.{match.group(1)}.{match.group(2)}.{match.group(3)}"
+            converted_sd[new_key] = v
+        else:
+            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
+
+    return converted_sd
+
+
+def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a FLUX transformer LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally
+    by InvokeAI.

Example key conversions: "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj" - "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj" - "lora_unet_double_blocks_0_img_attn_proj" -> "double_blocks.0.img_attn.proj" "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img_attn.qkv" - "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv" - "lora_unet_double_blocks_0_img_attn_qkv" -> "double_blocks.0.img.attn.qkv" """ def replace_func(match: re.Match[str]) -> str: @@ -70,9 +108,9 @@ def convert_flux_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> converted_dict: dict[str, T] = {} for k, v in state_dict.items(): - match = re.match(FLUX_KOHYA_KEY_REGEX, k) + match = re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) if match: - new_key = re.sub(FLUX_KOHYA_KEY_REGEX, replace_func, k) + new_key = re.sub(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, replace_func, k) converted_dict[new_key] = v else: raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.") diff --git a/tests/backend/lora/conversions/lora_state_dicts/flux_lora_kohya_with_te1_format.py b/tests/backend/lora/conversions/lora_state_dicts/flux_lora_kohya_with_te1_format.py new file mode 100644 index 0000000000..c43505e9c0 --- /dev/null +++ b/tests/backend/lora/conversions/lora_state_dicts/flux_lora_kohya_with_te1_format.py @@ -0,0 +1,1133 @@ +# A sample state dict in the Kohya FLUX LoRA format that patches both the transformer and T5 text encoder. +# These keys are based on the LoRA model here: +# https://huggingface.co/cocktailpeanut/optimus +state_dict_keys = [ + "lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_0_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_0_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_0_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_0_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_0_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_0_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_0_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_0_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_0_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_0_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_0_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_10_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_10_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_10_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_10_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_10_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_10_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_10_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_10_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_10_self_attn_k_proj.lora_up.weight", + 
"lora_te1_text_model_encoder_layers_10_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_10_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_10_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_10_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_10_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_10_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_10_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_10_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_10_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_11_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_11_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_11_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_11_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_11_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_11_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_11_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_11_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_11_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_11_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_11_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_11_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_11_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_11_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_11_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_11_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_11_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_11_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_1_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_1_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_1_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_1_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_1_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_1_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_1_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_1_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_1_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_1_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_1_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_1_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_1_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_1_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_1_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_1_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_1_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_1_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_2_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_2_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_2_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_2_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_2_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_2_mlp_fc2.lora_up.weight", + 
"lora_te1_text_model_encoder_layers_2_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_2_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_2_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_2_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_2_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_2_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_2_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_2_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_2_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_2_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_2_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_2_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_3_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_3_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_3_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_3_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_3_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_3_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_3_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_3_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_3_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_3_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_3_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_3_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_3_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_3_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_3_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_3_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_3_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_3_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_4_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_4_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_4_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_4_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_4_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_4_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_4_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_4_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_4_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_4_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_4_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_4_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_4_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_4_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_4_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_4_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_4_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_4_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_5_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_5_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_5_mlp_fc1.lora_up.weight", + 
"lora_te1_text_model_encoder_layers_5_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_5_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_5_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_5_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_5_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_5_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_5_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_5_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_5_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_5_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_5_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_5_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_5_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_5_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_5_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_6_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_6_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_6_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_6_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_6_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_6_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_6_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_6_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_6_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_6_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_6_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_6_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_6_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_6_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_6_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_6_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_6_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_6_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_7_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_7_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_7_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_7_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_7_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_7_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_7_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_7_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_7_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_7_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_7_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_7_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_7_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_7_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_7_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_7_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_7_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_7_self_attn_v_proj.lora_up.weight", + 
"lora_te1_text_model_encoder_layers_8_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_8_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_8_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_8_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_8_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_8_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_8_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_8_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_8_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_8_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_8_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_8_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_8_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_8_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_8_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_8_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_8_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_8_self_attn_v_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_9_mlp_fc1.alpha", + "lora_te1_text_model_encoder_layers_9_mlp_fc1.lora_down.weight", + "lora_te1_text_model_encoder_layers_9_mlp_fc1.lora_up.weight", + "lora_te1_text_model_encoder_layers_9_mlp_fc2.alpha", + "lora_te1_text_model_encoder_layers_9_mlp_fc2.lora_down.weight", + "lora_te1_text_model_encoder_layers_9_mlp_fc2.lora_up.weight", + "lora_te1_text_model_encoder_layers_9_self_attn_k_proj.alpha", + "lora_te1_text_model_encoder_layers_9_self_attn_k_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_9_self_attn_k_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_9_self_attn_out_proj.alpha", + "lora_te1_text_model_encoder_layers_9_self_attn_out_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_9_self_attn_out_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_9_self_attn_q_proj.alpha", + "lora_te1_text_model_encoder_layers_9_self_attn_q_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_9_self_attn_q_proj.lora_up.weight", + "lora_te1_text_model_encoder_layers_9_self_attn_v_proj.alpha", + "lora_te1_text_model_encoder_layers_9_self_attn_v_proj.lora_down.weight", + "lora_te1_text_model_encoder_layers_9_self_attn_v_proj.lora_up.weight", + "lora_unet_double_blocks_0_img_attn_proj.alpha", + "lora_unet_double_blocks_0_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_0_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_0_img_attn_qkv.alpha", + "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_0_img_mlp_0.alpha", + "lora_unet_double_blocks_0_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_0_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_0_img_mlp_2.alpha", + "lora_unet_double_blocks_0_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_0_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_0_img_mod_lin.alpha", + "lora_unet_double_blocks_0_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_0_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_0_txt_attn_proj.alpha", + "lora_unet_double_blocks_0_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_0_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_0_txt_attn_qkv.alpha", + 
"lora_unet_double_blocks_0_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_0_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_0_txt_mlp_0.alpha", + "lora_unet_double_blocks_0_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_0_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_0_txt_mlp_2.alpha", + "lora_unet_double_blocks_0_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_0_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_0_txt_mod_lin.alpha", + "lora_unet_double_blocks_0_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_0_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_10_img_attn_proj.alpha", + "lora_unet_double_blocks_10_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_10_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_10_img_attn_qkv.alpha", + "lora_unet_double_blocks_10_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_10_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_10_img_mlp_0.alpha", + "lora_unet_double_blocks_10_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_10_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_10_img_mlp_2.alpha", + "lora_unet_double_blocks_10_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_10_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_10_img_mod_lin.alpha", + "lora_unet_double_blocks_10_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_10_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_10_txt_attn_proj.alpha", + "lora_unet_double_blocks_10_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_10_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_10_txt_attn_qkv.alpha", + "lora_unet_double_blocks_10_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_10_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_10_txt_mlp_0.alpha", + "lora_unet_double_blocks_10_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_10_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_10_txt_mlp_2.alpha", + "lora_unet_double_blocks_10_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_10_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_10_txt_mod_lin.alpha", + "lora_unet_double_blocks_10_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_10_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_11_img_attn_proj.alpha", + "lora_unet_double_blocks_11_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_11_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_11_img_attn_qkv.alpha", + "lora_unet_double_blocks_11_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_11_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_11_img_mlp_0.alpha", + "lora_unet_double_blocks_11_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_11_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_11_img_mlp_2.alpha", + "lora_unet_double_blocks_11_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_11_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_11_img_mod_lin.alpha", + "lora_unet_double_blocks_11_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_11_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_11_txt_attn_proj.alpha", + "lora_unet_double_blocks_11_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_11_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_11_txt_attn_qkv.alpha", + "lora_unet_double_blocks_11_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_11_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_11_txt_mlp_0.alpha", + 
"lora_unet_double_blocks_11_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_11_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_11_txt_mlp_2.alpha", + "lora_unet_double_blocks_11_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_11_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_11_txt_mod_lin.alpha", + "lora_unet_double_blocks_11_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_11_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_12_img_attn_proj.alpha", + "lora_unet_double_blocks_12_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_12_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_12_img_attn_qkv.alpha", + "lora_unet_double_blocks_12_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_12_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_12_img_mlp_0.alpha", + "lora_unet_double_blocks_12_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_12_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_12_img_mlp_2.alpha", + "lora_unet_double_blocks_12_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_12_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_12_img_mod_lin.alpha", + "lora_unet_double_blocks_12_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_12_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_12_txt_attn_proj.alpha", + "lora_unet_double_blocks_12_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_12_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_12_txt_attn_qkv.alpha", + "lora_unet_double_blocks_12_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_12_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_12_txt_mlp_0.alpha", + "lora_unet_double_blocks_12_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_12_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_12_txt_mlp_2.alpha", + "lora_unet_double_blocks_12_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_12_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_12_txt_mod_lin.alpha", + "lora_unet_double_blocks_12_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_12_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_13_img_attn_proj.alpha", + "lora_unet_double_blocks_13_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_13_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_13_img_attn_qkv.alpha", + "lora_unet_double_blocks_13_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_13_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_13_img_mlp_0.alpha", + "lora_unet_double_blocks_13_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_13_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_13_img_mlp_2.alpha", + "lora_unet_double_blocks_13_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_13_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_13_img_mod_lin.alpha", + "lora_unet_double_blocks_13_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_13_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_13_txt_attn_proj.alpha", + "lora_unet_double_blocks_13_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_13_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_13_txt_attn_qkv.alpha", + "lora_unet_double_blocks_13_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_13_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_13_txt_mlp_0.alpha", + "lora_unet_double_blocks_13_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_13_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_13_txt_mlp_2.alpha", + 
"lora_unet_double_blocks_13_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_13_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_13_txt_mod_lin.alpha", + "lora_unet_double_blocks_13_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_13_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_14_img_attn_proj.alpha", + "lora_unet_double_blocks_14_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_14_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_14_img_attn_qkv.alpha", + "lora_unet_double_blocks_14_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_14_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_14_img_mlp_0.alpha", + "lora_unet_double_blocks_14_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_14_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_14_img_mlp_2.alpha", + "lora_unet_double_blocks_14_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_14_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_14_img_mod_lin.alpha", + "lora_unet_double_blocks_14_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_14_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_14_txt_attn_proj.alpha", + "lora_unet_double_blocks_14_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_14_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_14_txt_attn_qkv.alpha", + "lora_unet_double_blocks_14_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_14_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_14_txt_mlp_0.alpha", + "lora_unet_double_blocks_14_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_14_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_14_txt_mlp_2.alpha", + "lora_unet_double_blocks_14_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_14_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_14_txt_mod_lin.alpha", + "lora_unet_double_blocks_14_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_14_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_15_img_attn_proj.alpha", + "lora_unet_double_blocks_15_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_15_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_15_img_attn_qkv.alpha", + "lora_unet_double_blocks_15_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_15_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_15_img_mlp_0.alpha", + "lora_unet_double_blocks_15_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_15_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_15_img_mlp_2.alpha", + "lora_unet_double_blocks_15_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_15_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_15_img_mod_lin.alpha", + "lora_unet_double_blocks_15_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_15_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_15_txt_attn_proj.alpha", + "lora_unet_double_blocks_15_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_15_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_15_txt_attn_qkv.alpha", + "lora_unet_double_blocks_15_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_15_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_15_txt_mlp_0.alpha", + "lora_unet_double_blocks_15_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_15_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_15_txt_mlp_2.alpha", + "lora_unet_double_blocks_15_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_15_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_15_txt_mod_lin.alpha", + 
"lora_unet_double_blocks_15_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_15_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_16_img_attn_proj.alpha", + "lora_unet_double_blocks_16_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_16_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_16_img_attn_qkv.alpha", + "lora_unet_double_blocks_16_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_16_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_16_img_mlp_0.alpha", + "lora_unet_double_blocks_16_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_16_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_16_img_mlp_2.alpha", + "lora_unet_double_blocks_16_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_16_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_16_img_mod_lin.alpha", + "lora_unet_double_blocks_16_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_16_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_16_txt_attn_proj.alpha", + "lora_unet_double_blocks_16_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_16_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_16_txt_attn_qkv.alpha", + "lora_unet_double_blocks_16_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_16_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_16_txt_mlp_0.alpha", + "lora_unet_double_blocks_16_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_16_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_16_txt_mlp_2.alpha", + "lora_unet_double_blocks_16_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_16_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_16_txt_mod_lin.alpha", + "lora_unet_double_blocks_16_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_16_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_17_img_attn_proj.alpha", + "lora_unet_double_blocks_17_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_17_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_17_img_attn_qkv.alpha", + "lora_unet_double_blocks_17_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_17_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_17_img_mlp_0.alpha", + "lora_unet_double_blocks_17_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_17_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_17_img_mlp_2.alpha", + "lora_unet_double_blocks_17_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_17_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_17_img_mod_lin.alpha", + "lora_unet_double_blocks_17_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_17_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_17_txt_attn_proj.alpha", + "lora_unet_double_blocks_17_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_17_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_17_txt_attn_qkv.alpha", + "lora_unet_double_blocks_17_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_17_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_17_txt_mlp_0.alpha", + "lora_unet_double_blocks_17_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_17_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_17_txt_mlp_2.alpha", + "lora_unet_double_blocks_17_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_17_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_17_txt_mod_lin.alpha", + "lora_unet_double_blocks_17_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_17_txt_mod_lin.lora_up.weight", + 
"lora_unet_double_blocks_18_img_attn_proj.alpha", + "lora_unet_double_blocks_18_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_18_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_18_img_attn_qkv.alpha", + "lora_unet_double_blocks_18_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_18_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_18_img_mlp_0.alpha", + "lora_unet_double_blocks_18_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_18_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_18_img_mlp_2.alpha", + "lora_unet_double_blocks_18_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_18_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_18_img_mod_lin.alpha", + "lora_unet_double_blocks_18_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_18_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_18_txt_attn_proj.alpha", + "lora_unet_double_blocks_18_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_18_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_18_txt_attn_qkv.alpha", + "lora_unet_double_blocks_18_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_18_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_18_txt_mlp_0.alpha", + "lora_unet_double_blocks_18_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_18_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_18_txt_mlp_2.alpha", + "lora_unet_double_blocks_18_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_18_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_18_txt_mod_lin.alpha", + "lora_unet_double_blocks_18_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_18_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_1_img_attn_proj.alpha", + "lora_unet_double_blocks_1_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_1_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_1_img_attn_qkv.alpha", + "lora_unet_double_blocks_1_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_1_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_1_img_mlp_0.alpha", + "lora_unet_double_blocks_1_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_1_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_1_img_mlp_2.alpha", + "lora_unet_double_blocks_1_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_1_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_1_img_mod_lin.alpha", + "lora_unet_double_blocks_1_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_1_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_1_txt_attn_proj.alpha", + "lora_unet_double_blocks_1_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_1_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_1_txt_attn_qkv.alpha", + "lora_unet_double_blocks_1_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_1_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_1_txt_mlp_0.alpha", + "lora_unet_double_blocks_1_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_1_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_1_txt_mlp_2.alpha", + "lora_unet_double_blocks_1_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_1_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_1_txt_mod_lin.alpha", + "lora_unet_double_blocks_1_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_1_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_2_img_attn_proj.alpha", + "lora_unet_double_blocks_2_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_2_img_attn_proj.lora_up.weight", + 
"lora_unet_double_blocks_2_img_attn_qkv.alpha", + "lora_unet_double_blocks_2_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_2_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_2_img_mlp_0.alpha", + "lora_unet_double_blocks_2_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_2_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_2_img_mlp_2.alpha", + "lora_unet_double_blocks_2_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_2_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_2_img_mod_lin.alpha", + "lora_unet_double_blocks_2_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_2_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_2_txt_attn_proj.alpha", + "lora_unet_double_blocks_2_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_2_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_2_txt_attn_qkv.alpha", + "lora_unet_double_blocks_2_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_2_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_2_txt_mlp_0.alpha", + "lora_unet_double_blocks_2_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_2_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_2_txt_mlp_2.alpha", + "lora_unet_double_blocks_2_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_2_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_2_txt_mod_lin.alpha", + "lora_unet_double_blocks_2_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_2_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_3_img_attn_proj.alpha", + "lora_unet_double_blocks_3_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_3_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_3_img_attn_qkv.alpha", + "lora_unet_double_blocks_3_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_3_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_3_img_mlp_0.alpha", + "lora_unet_double_blocks_3_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_3_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_3_img_mlp_2.alpha", + "lora_unet_double_blocks_3_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_3_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_3_img_mod_lin.alpha", + "lora_unet_double_blocks_3_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_3_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_3_txt_attn_proj.alpha", + "lora_unet_double_blocks_3_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_3_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_3_txt_attn_qkv.alpha", + "lora_unet_double_blocks_3_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_3_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_3_txt_mlp_0.alpha", + "lora_unet_double_blocks_3_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_3_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_3_txt_mlp_2.alpha", + "lora_unet_double_blocks_3_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_3_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_3_txt_mod_lin.alpha", + "lora_unet_double_blocks_3_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_3_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_4_img_attn_proj.alpha", + "lora_unet_double_blocks_4_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_4_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_4_img_attn_qkv.alpha", + "lora_unet_double_blocks_4_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_4_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_4_img_mlp_0.alpha", + 
"lora_unet_double_blocks_4_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_4_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_4_img_mlp_2.alpha", + "lora_unet_double_blocks_4_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_4_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_4_img_mod_lin.alpha", + "lora_unet_double_blocks_4_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_4_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_4_txt_attn_proj.alpha", + "lora_unet_double_blocks_4_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_4_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_4_txt_attn_qkv.alpha", + "lora_unet_double_blocks_4_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_4_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_4_txt_mlp_0.alpha", + "lora_unet_double_blocks_4_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_4_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_4_txt_mlp_2.alpha", + "lora_unet_double_blocks_4_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_4_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_4_txt_mod_lin.alpha", + "lora_unet_double_blocks_4_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_4_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_5_img_attn_proj.alpha", + "lora_unet_double_blocks_5_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_5_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_5_img_attn_qkv.alpha", + "lora_unet_double_blocks_5_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_5_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_5_img_mlp_0.alpha", + "lora_unet_double_blocks_5_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_5_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_5_img_mlp_2.alpha", + "lora_unet_double_blocks_5_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_5_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_5_img_mod_lin.alpha", + "lora_unet_double_blocks_5_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_5_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_5_txt_attn_proj.alpha", + "lora_unet_double_blocks_5_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_5_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_5_txt_attn_qkv.alpha", + "lora_unet_double_blocks_5_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_5_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_5_txt_mlp_0.alpha", + "lora_unet_double_blocks_5_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_5_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_5_txt_mlp_2.alpha", + "lora_unet_double_blocks_5_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_5_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_5_txt_mod_lin.alpha", + "lora_unet_double_blocks_5_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_5_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_6_img_attn_proj.alpha", + "lora_unet_double_blocks_6_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_6_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_6_img_attn_qkv.alpha", + "lora_unet_double_blocks_6_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_6_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_6_img_mlp_0.alpha", + "lora_unet_double_blocks_6_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_6_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_6_img_mlp_2.alpha", + "lora_unet_double_blocks_6_img_mlp_2.lora_down.weight", + 
"lora_unet_double_blocks_6_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_6_img_mod_lin.alpha", + "lora_unet_double_blocks_6_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_6_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_6_txt_attn_proj.alpha", + "lora_unet_double_blocks_6_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_6_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_6_txt_attn_qkv.alpha", + "lora_unet_double_blocks_6_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_6_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_6_txt_mlp_0.alpha", + "lora_unet_double_blocks_6_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_6_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_6_txt_mlp_2.alpha", + "lora_unet_double_blocks_6_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_6_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_6_txt_mod_lin.alpha", + "lora_unet_double_blocks_6_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_6_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_7_img_attn_proj.alpha", + "lora_unet_double_blocks_7_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_7_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_7_img_attn_qkv.alpha", + "lora_unet_double_blocks_7_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_7_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_7_img_mlp_0.alpha", + "lora_unet_double_blocks_7_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_7_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_7_img_mlp_2.alpha", + "lora_unet_double_blocks_7_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_7_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_7_img_mod_lin.alpha", + "lora_unet_double_blocks_7_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_7_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_7_txt_attn_proj.alpha", + "lora_unet_double_blocks_7_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_7_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_7_txt_attn_qkv.alpha", + "lora_unet_double_blocks_7_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_7_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_7_txt_mlp_0.alpha", + "lora_unet_double_blocks_7_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_7_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_7_txt_mlp_2.alpha", + "lora_unet_double_blocks_7_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_7_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_7_txt_mod_lin.alpha", + "lora_unet_double_blocks_7_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_7_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_8_img_attn_proj.alpha", + "lora_unet_double_blocks_8_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_8_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_8_img_attn_qkv.alpha", + "lora_unet_double_blocks_8_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_8_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_8_img_mlp_0.alpha", + "lora_unet_double_blocks_8_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_8_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_8_img_mlp_2.alpha", + "lora_unet_double_blocks_8_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_8_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_8_img_mod_lin.alpha", + "lora_unet_double_blocks_8_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_8_img_mod_lin.lora_up.weight", + 
"lora_unet_double_blocks_8_txt_attn_proj.alpha", + "lora_unet_double_blocks_8_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_8_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_8_txt_attn_qkv.alpha", + "lora_unet_double_blocks_8_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_8_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_8_txt_mlp_0.alpha", + "lora_unet_double_blocks_8_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_8_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_8_txt_mlp_2.alpha", + "lora_unet_double_blocks_8_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_8_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_8_txt_mod_lin.alpha", + "lora_unet_double_blocks_8_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_8_txt_mod_lin.lora_up.weight", + "lora_unet_double_blocks_9_img_attn_proj.alpha", + "lora_unet_double_blocks_9_img_attn_proj.lora_down.weight", + "lora_unet_double_blocks_9_img_attn_proj.lora_up.weight", + "lora_unet_double_blocks_9_img_attn_qkv.alpha", + "lora_unet_double_blocks_9_img_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_9_img_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_9_img_mlp_0.alpha", + "lora_unet_double_blocks_9_img_mlp_0.lora_down.weight", + "lora_unet_double_blocks_9_img_mlp_0.lora_up.weight", + "lora_unet_double_blocks_9_img_mlp_2.alpha", + "lora_unet_double_blocks_9_img_mlp_2.lora_down.weight", + "lora_unet_double_blocks_9_img_mlp_2.lora_up.weight", + "lora_unet_double_blocks_9_img_mod_lin.alpha", + "lora_unet_double_blocks_9_img_mod_lin.lora_down.weight", + "lora_unet_double_blocks_9_img_mod_lin.lora_up.weight", + "lora_unet_double_blocks_9_txt_attn_proj.alpha", + "lora_unet_double_blocks_9_txt_attn_proj.lora_down.weight", + "lora_unet_double_blocks_9_txt_attn_proj.lora_up.weight", + "lora_unet_double_blocks_9_txt_attn_qkv.alpha", + "lora_unet_double_blocks_9_txt_attn_qkv.lora_down.weight", + "lora_unet_double_blocks_9_txt_attn_qkv.lora_up.weight", + "lora_unet_double_blocks_9_txt_mlp_0.alpha", + "lora_unet_double_blocks_9_txt_mlp_0.lora_down.weight", + "lora_unet_double_blocks_9_txt_mlp_0.lora_up.weight", + "lora_unet_double_blocks_9_txt_mlp_2.alpha", + "lora_unet_double_blocks_9_txt_mlp_2.lora_down.weight", + "lora_unet_double_blocks_9_txt_mlp_2.lora_up.weight", + "lora_unet_double_blocks_9_txt_mod_lin.alpha", + "lora_unet_double_blocks_9_txt_mod_lin.lora_down.weight", + "lora_unet_double_blocks_9_txt_mod_lin.lora_up.weight", + "lora_unet_single_blocks_0_linear1.alpha", + "lora_unet_single_blocks_0_linear1.lora_down.weight", + "lora_unet_single_blocks_0_linear1.lora_up.weight", + "lora_unet_single_blocks_0_linear2.alpha", + "lora_unet_single_blocks_0_linear2.lora_down.weight", + "lora_unet_single_blocks_0_linear2.lora_up.weight", + "lora_unet_single_blocks_0_modulation_lin.alpha", + "lora_unet_single_blocks_0_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_0_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_10_linear1.alpha", + "lora_unet_single_blocks_10_linear1.lora_down.weight", + "lora_unet_single_blocks_10_linear1.lora_up.weight", + "lora_unet_single_blocks_10_linear2.alpha", + "lora_unet_single_blocks_10_linear2.lora_down.weight", + "lora_unet_single_blocks_10_linear2.lora_up.weight", + "lora_unet_single_blocks_10_modulation_lin.alpha", + "lora_unet_single_blocks_10_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_10_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_11_linear1.alpha", + 
"lora_unet_single_blocks_11_linear1.lora_down.weight", + "lora_unet_single_blocks_11_linear1.lora_up.weight", + "lora_unet_single_blocks_11_linear2.alpha", + "lora_unet_single_blocks_11_linear2.lora_down.weight", + "lora_unet_single_blocks_11_linear2.lora_up.weight", + "lora_unet_single_blocks_11_modulation_lin.alpha", + "lora_unet_single_blocks_11_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_11_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_12_linear1.alpha", + "lora_unet_single_blocks_12_linear1.lora_down.weight", + "lora_unet_single_blocks_12_linear1.lora_up.weight", + "lora_unet_single_blocks_12_linear2.alpha", + "lora_unet_single_blocks_12_linear2.lora_down.weight", + "lora_unet_single_blocks_12_linear2.lora_up.weight", + "lora_unet_single_blocks_12_modulation_lin.alpha", + "lora_unet_single_blocks_12_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_12_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_13_linear1.alpha", + "lora_unet_single_blocks_13_linear1.lora_down.weight", + "lora_unet_single_blocks_13_linear1.lora_up.weight", + "lora_unet_single_blocks_13_linear2.alpha", + "lora_unet_single_blocks_13_linear2.lora_down.weight", + "lora_unet_single_blocks_13_linear2.lora_up.weight", + "lora_unet_single_blocks_13_modulation_lin.alpha", + "lora_unet_single_blocks_13_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_13_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_14_linear1.alpha", + "lora_unet_single_blocks_14_linear1.lora_down.weight", + "lora_unet_single_blocks_14_linear1.lora_up.weight", + "lora_unet_single_blocks_14_linear2.alpha", + "lora_unet_single_blocks_14_linear2.lora_down.weight", + "lora_unet_single_blocks_14_linear2.lora_up.weight", + "lora_unet_single_blocks_14_modulation_lin.alpha", + "lora_unet_single_blocks_14_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_14_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_15_linear1.alpha", + "lora_unet_single_blocks_15_linear1.lora_down.weight", + "lora_unet_single_blocks_15_linear1.lora_up.weight", + "lora_unet_single_blocks_15_linear2.alpha", + "lora_unet_single_blocks_15_linear2.lora_down.weight", + "lora_unet_single_blocks_15_linear2.lora_up.weight", + "lora_unet_single_blocks_15_modulation_lin.alpha", + "lora_unet_single_blocks_15_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_15_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_16_linear1.alpha", + "lora_unet_single_blocks_16_linear1.lora_down.weight", + "lora_unet_single_blocks_16_linear1.lora_up.weight", + "lora_unet_single_blocks_16_linear2.alpha", + "lora_unet_single_blocks_16_linear2.lora_down.weight", + "lora_unet_single_blocks_16_linear2.lora_up.weight", + "lora_unet_single_blocks_16_modulation_lin.alpha", + "lora_unet_single_blocks_16_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_16_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_17_linear1.alpha", + "lora_unet_single_blocks_17_linear1.lora_down.weight", + "lora_unet_single_blocks_17_linear1.lora_up.weight", + "lora_unet_single_blocks_17_linear2.alpha", + "lora_unet_single_blocks_17_linear2.lora_down.weight", + "lora_unet_single_blocks_17_linear2.lora_up.weight", + "lora_unet_single_blocks_17_modulation_lin.alpha", + "lora_unet_single_blocks_17_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_17_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_18_linear1.alpha", + "lora_unet_single_blocks_18_linear1.lora_down.weight", + 
"lora_unet_single_blocks_18_linear1.lora_up.weight", + "lora_unet_single_blocks_18_linear2.alpha", + "lora_unet_single_blocks_18_linear2.lora_down.weight", + "lora_unet_single_blocks_18_linear2.lora_up.weight", + "lora_unet_single_blocks_18_modulation_lin.alpha", + "lora_unet_single_blocks_18_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_18_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_19_linear1.alpha", + "lora_unet_single_blocks_19_linear1.lora_down.weight", + "lora_unet_single_blocks_19_linear1.lora_up.weight", + "lora_unet_single_blocks_19_linear2.alpha", + "lora_unet_single_blocks_19_linear2.lora_down.weight", + "lora_unet_single_blocks_19_linear2.lora_up.weight", + "lora_unet_single_blocks_19_modulation_lin.alpha", + "lora_unet_single_blocks_19_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_19_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_1_linear1.alpha", + "lora_unet_single_blocks_1_linear1.lora_down.weight", + "lora_unet_single_blocks_1_linear1.lora_up.weight", + "lora_unet_single_blocks_1_linear2.alpha", + "lora_unet_single_blocks_1_linear2.lora_down.weight", + "lora_unet_single_blocks_1_linear2.lora_up.weight", + "lora_unet_single_blocks_1_modulation_lin.alpha", + "lora_unet_single_blocks_1_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_1_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_20_linear1.alpha", + "lora_unet_single_blocks_20_linear1.lora_down.weight", + "lora_unet_single_blocks_20_linear1.lora_up.weight", + "lora_unet_single_blocks_20_linear2.alpha", + "lora_unet_single_blocks_20_linear2.lora_down.weight", + "lora_unet_single_blocks_20_linear2.lora_up.weight", + "lora_unet_single_blocks_20_modulation_lin.alpha", + "lora_unet_single_blocks_20_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_20_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_21_linear1.alpha", + "lora_unet_single_blocks_21_linear1.lora_down.weight", + "lora_unet_single_blocks_21_linear1.lora_up.weight", + "lora_unet_single_blocks_21_linear2.alpha", + "lora_unet_single_blocks_21_linear2.lora_down.weight", + "lora_unet_single_blocks_21_linear2.lora_up.weight", + "lora_unet_single_blocks_21_modulation_lin.alpha", + "lora_unet_single_blocks_21_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_21_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_22_linear1.alpha", + "lora_unet_single_blocks_22_linear1.lora_down.weight", + "lora_unet_single_blocks_22_linear1.lora_up.weight", + "lora_unet_single_blocks_22_linear2.alpha", + "lora_unet_single_blocks_22_linear2.lora_down.weight", + "lora_unet_single_blocks_22_linear2.lora_up.weight", + "lora_unet_single_blocks_22_modulation_lin.alpha", + "lora_unet_single_blocks_22_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_22_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_23_linear1.alpha", + "lora_unet_single_blocks_23_linear1.lora_down.weight", + "lora_unet_single_blocks_23_linear1.lora_up.weight", + "lora_unet_single_blocks_23_linear2.alpha", + "lora_unet_single_blocks_23_linear2.lora_down.weight", + "lora_unet_single_blocks_23_linear2.lora_up.weight", + "lora_unet_single_blocks_23_modulation_lin.alpha", + "lora_unet_single_blocks_23_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_23_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_24_linear1.alpha", + "lora_unet_single_blocks_24_linear1.lora_down.weight", + "lora_unet_single_blocks_24_linear1.lora_up.weight", + "lora_unet_single_blocks_24_linear2.alpha", 
+ "lora_unet_single_blocks_24_linear2.lora_down.weight", + "lora_unet_single_blocks_24_linear2.lora_up.weight", + "lora_unet_single_blocks_24_modulation_lin.alpha", + "lora_unet_single_blocks_24_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_24_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_25_linear1.alpha", + "lora_unet_single_blocks_25_linear1.lora_down.weight", + "lora_unet_single_blocks_25_linear1.lora_up.weight", + "lora_unet_single_blocks_25_linear2.alpha", + "lora_unet_single_blocks_25_linear2.lora_down.weight", + "lora_unet_single_blocks_25_linear2.lora_up.weight", + "lora_unet_single_blocks_25_modulation_lin.alpha", + "lora_unet_single_blocks_25_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_25_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_26_linear1.alpha", + "lora_unet_single_blocks_26_linear1.lora_down.weight", + "lora_unet_single_blocks_26_linear1.lora_up.weight", + "lora_unet_single_blocks_26_linear2.alpha", + "lora_unet_single_blocks_26_linear2.lora_down.weight", + "lora_unet_single_blocks_26_linear2.lora_up.weight", + "lora_unet_single_blocks_26_modulation_lin.alpha", + "lora_unet_single_blocks_26_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_26_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_27_linear1.alpha", + "lora_unet_single_blocks_27_linear1.lora_down.weight", + "lora_unet_single_blocks_27_linear1.lora_up.weight", + "lora_unet_single_blocks_27_linear2.alpha", + "lora_unet_single_blocks_27_linear2.lora_down.weight", + "lora_unet_single_blocks_27_linear2.lora_up.weight", + "lora_unet_single_blocks_27_modulation_lin.alpha", + "lora_unet_single_blocks_27_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_27_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_28_linear1.alpha", + "lora_unet_single_blocks_28_linear1.lora_down.weight", + "lora_unet_single_blocks_28_linear1.lora_up.weight", + "lora_unet_single_blocks_28_linear2.alpha", + "lora_unet_single_blocks_28_linear2.lora_down.weight", + "lora_unet_single_blocks_28_linear2.lora_up.weight", + "lora_unet_single_blocks_28_modulation_lin.alpha", + "lora_unet_single_blocks_28_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_28_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_29_linear1.alpha", + "lora_unet_single_blocks_29_linear1.lora_down.weight", + "lora_unet_single_blocks_29_linear1.lora_up.weight", + "lora_unet_single_blocks_29_linear2.alpha", + "lora_unet_single_blocks_29_linear2.lora_down.weight", + "lora_unet_single_blocks_29_linear2.lora_up.weight", + "lora_unet_single_blocks_29_modulation_lin.alpha", + "lora_unet_single_blocks_29_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_29_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_2_linear1.alpha", + "lora_unet_single_blocks_2_linear1.lora_down.weight", + "lora_unet_single_blocks_2_linear1.lora_up.weight", + "lora_unet_single_blocks_2_linear2.alpha", + "lora_unet_single_blocks_2_linear2.lora_down.weight", + "lora_unet_single_blocks_2_linear2.lora_up.weight", + "lora_unet_single_blocks_2_modulation_lin.alpha", + "lora_unet_single_blocks_2_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_2_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_30_linear1.alpha", + "lora_unet_single_blocks_30_linear1.lora_down.weight", + "lora_unet_single_blocks_30_linear1.lora_up.weight", + "lora_unet_single_blocks_30_linear2.alpha", + "lora_unet_single_blocks_30_linear2.lora_down.weight", + 
"lora_unet_single_blocks_30_linear2.lora_up.weight", + "lora_unet_single_blocks_30_modulation_lin.alpha", + "lora_unet_single_blocks_30_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_30_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_31_linear1.alpha", + "lora_unet_single_blocks_31_linear1.lora_down.weight", + "lora_unet_single_blocks_31_linear1.lora_up.weight", + "lora_unet_single_blocks_31_linear2.alpha", + "lora_unet_single_blocks_31_linear2.lora_down.weight", + "lora_unet_single_blocks_31_linear2.lora_up.weight", + "lora_unet_single_blocks_31_modulation_lin.alpha", + "lora_unet_single_blocks_31_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_31_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_32_linear1.alpha", + "lora_unet_single_blocks_32_linear1.lora_down.weight", + "lora_unet_single_blocks_32_linear1.lora_up.weight", + "lora_unet_single_blocks_32_linear2.alpha", + "lora_unet_single_blocks_32_linear2.lora_down.weight", + "lora_unet_single_blocks_32_linear2.lora_up.weight", + "lora_unet_single_blocks_32_modulation_lin.alpha", + "lora_unet_single_blocks_32_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_32_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_33_linear1.alpha", + "lora_unet_single_blocks_33_linear1.lora_down.weight", + "lora_unet_single_blocks_33_linear1.lora_up.weight", + "lora_unet_single_blocks_33_linear2.alpha", + "lora_unet_single_blocks_33_linear2.lora_down.weight", + "lora_unet_single_blocks_33_linear2.lora_up.weight", + "lora_unet_single_blocks_33_modulation_lin.alpha", + "lora_unet_single_blocks_33_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_33_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_34_linear1.alpha", + "lora_unet_single_blocks_34_linear1.lora_down.weight", + "lora_unet_single_blocks_34_linear1.lora_up.weight", + "lora_unet_single_blocks_34_linear2.alpha", + "lora_unet_single_blocks_34_linear2.lora_down.weight", + "lora_unet_single_blocks_34_linear2.lora_up.weight", + "lora_unet_single_blocks_34_modulation_lin.alpha", + "lora_unet_single_blocks_34_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_34_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_35_linear1.alpha", + "lora_unet_single_blocks_35_linear1.lora_down.weight", + "lora_unet_single_blocks_35_linear1.lora_up.weight", + "lora_unet_single_blocks_35_linear2.alpha", + "lora_unet_single_blocks_35_linear2.lora_down.weight", + "lora_unet_single_blocks_35_linear2.lora_up.weight", + "lora_unet_single_blocks_35_modulation_lin.alpha", + "lora_unet_single_blocks_35_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_35_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_36_linear1.alpha", + "lora_unet_single_blocks_36_linear1.lora_down.weight", + "lora_unet_single_blocks_36_linear1.lora_up.weight", + "lora_unet_single_blocks_36_linear2.alpha", + "lora_unet_single_blocks_36_linear2.lora_down.weight", + "lora_unet_single_blocks_36_linear2.lora_up.weight", + "lora_unet_single_blocks_36_modulation_lin.alpha", + "lora_unet_single_blocks_36_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_36_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_37_linear1.alpha", + "lora_unet_single_blocks_37_linear1.lora_down.weight", + "lora_unet_single_blocks_37_linear1.lora_up.weight", + "lora_unet_single_blocks_37_linear2.alpha", + "lora_unet_single_blocks_37_linear2.lora_down.weight", + "lora_unet_single_blocks_37_linear2.lora_up.weight", + 
"lora_unet_single_blocks_37_modulation_lin.alpha", + "lora_unet_single_blocks_37_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_37_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_3_linear1.alpha", + "lora_unet_single_blocks_3_linear1.lora_down.weight", + "lora_unet_single_blocks_3_linear1.lora_up.weight", + "lora_unet_single_blocks_3_linear2.alpha", + "lora_unet_single_blocks_3_linear2.lora_down.weight", + "lora_unet_single_blocks_3_linear2.lora_up.weight", + "lora_unet_single_blocks_3_modulation_lin.alpha", + "lora_unet_single_blocks_3_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_3_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_4_linear1.alpha", + "lora_unet_single_blocks_4_linear1.lora_down.weight", + "lora_unet_single_blocks_4_linear1.lora_up.weight", + "lora_unet_single_blocks_4_linear2.alpha", + "lora_unet_single_blocks_4_linear2.lora_down.weight", + "lora_unet_single_blocks_4_linear2.lora_up.weight", + "lora_unet_single_blocks_4_modulation_lin.alpha", + "lora_unet_single_blocks_4_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_4_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_5_linear1.alpha", + "lora_unet_single_blocks_5_linear1.lora_down.weight", + "lora_unet_single_blocks_5_linear1.lora_up.weight", + "lora_unet_single_blocks_5_linear2.alpha", + "lora_unet_single_blocks_5_linear2.lora_down.weight", + "lora_unet_single_blocks_5_linear2.lora_up.weight", + "lora_unet_single_blocks_5_modulation_lin.alpha", + "lora_unet_single_blocks_5_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_5_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_6_linear1.alpha", + "lora_unet_single_blocks_6_linear1.lora_down.weight", + "lora_unet_single_blocks_6_linear1.lora_up.weight", + "lora_unet_single_blocks_6_linear2.alpha", + "lora_unet_single_blocks_6_linear2.lora_down.weight", + "lora_unet_single_blocks_6_linear2.lora_up.weight", + "lora_unet_single_blocks_6_modulation_lin.alpha", + "lora_unet_single_blocks_6_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_6_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_7_linear1.alpha", + "lora_unet_single_blocks_7_linear1.lora_down.weight", + "lora_unet_single_blocks_7_linear1.lora_up.weight", + "lora_unet_single_blocks_7_linear2.alpha", + "lora_unet_single_blocks_7_linear2.lora_down.weight", + "lora_unet_single_blocks_7_linear2.lora_up.weight", + "lora_unet_single_blocks_7_modulation_lin.alpha", + "lora_unet_single_blocks_7_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_7_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_8_linear1.alpha", + "lora_unet_single_blocks_8_linear1.lora_down.weight", + "lora_unet_single_blocks_8_linear1.lora_up.weight", + "lora_unet_single_blocks_8_linear2.alpha", + "lora_unet_single_blocks_8_linear2.lora_down.weight", + "lora_unet_single_blocks_8_linear2.lora_up.weight", + "lora_unet_single_blocks_8_modulation_lin.alpha", + "lora_unet_single_blocks_8_modulation_lin.lora_down.weight", + "lora_unet_single_blocks_8_modulation_lin.lora_up.weight", + "lora_unet_single_blocks_9_linear1.alpha", + "lora_unet_single_blocks_9_linear1.lora_down.weight", + "lora_unet_single_blocks_9_linear1.lora_up.weight", + "lora_unet_single_blocks_9_linear2.alpha", + "lora_unet_single_blocks_9_linear2.lora_down.weight", + "lora_unet_single_blocks_9_linear2.lora_up.weight", + "lora_unet_single_blocks_9_modulation_lin.alpha", + "lora_unet_single_blocks_9_modulation_lin.lora_down.weight", + 
"lora_unet_single_blocks_9_modulation_lin.lora_up.weight", +] diff --git a/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py b/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py index b34c404e2f..e713b02c02 100644 --- a/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py +++ b/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py @@ -5,7 +5,7 @@ import torch from invokeai.backend.flux.model import Flux from invokeai.backend.flux.util import params from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import ( - convert_flux_kohya_state_dict_to_invoke_format, + _convert_flux_transformer_kohya_state_dict_to_invoke_format, is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict, ) @@ -15,13 +15,17 @@ from tests.backend.lora.conversions.lora_state_dicts.flux_lora_diffusers_format from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_format import ( state_dict_keys as flux_kohya_state_dict_keys, ) +from tests.backend.lora.conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import ( + state_dict_keys as flux_kohya_te1_state_dict_keys, +) from tests.backend.lora.conversions.lora_state_dicts.utils import keys_to_mock_state_dict -def test_is_state_dict_likely_in_flux_kohya_format_true(): +@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys]) +def test_is_state_dict_likely_in_flux_kohya_format_true(sd_keys: list[str]): """Test that is_state_dict_likely_in_flux_kohya_format() can identify a state dict in the Kohya FLUX LoRA format.""" # Construct a state dict that is in the Kohya FLUX LoRA format. - state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys) + state_dict = keys_to_mock_state_dict(sd_keys) assert is_state_dict_likely_in_flux_kohya_format(state_dict) @@ -34,11 +38,11 @@ def test_is_state_dict_likely_in_flux_kohya_format_false(): assert not is_state_dict_likely_in_flux_kohya_format(state_dict) -def test_convert_flux_kohya_state_dict_to_invoke_format(): +def test_convert_flux_transformer_kohya_state_dict_to_invoke_format(): # Construct state_dict from state_dict_keys. state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys) - converted_state_dict = convert_flux_kohya_state_dict_to_invoke_format(state_dict) + converted_state_dict = _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict) # Extract the prefixes from the converted state dict (i.e. without the .lora_up.weight, .lora_down.weight, and # .alpha suffixes). @@ -65,29 +69,33 @@ def test_convert_flux_kohya_state_dict_to_invoke_format(): raise AssertionError(f"Could not find a match for the converted key prefix: {converted_key_prefix}") -def test_convert_flux_kohya_state_dict_to_invoke_format_error(): - """Test that an error is raised by convert_flux_kohya_state_dict_to_invoke_format() if the input state_dict contains - unexpected keys. +def test_convert_flux_transformer_kohya_state_dict_to_invoke_format_error(): + """Test that an error is raised by _convert_flux_transformer_kohya_state_dict_to_invoke_format() if the input + state_dict contains unexpected keys. 
""" state_dict = { "unexpected_key.lora_up.weight": torch.empty(1), } with pytest.raises(ValueError): - convert_flux_kohya_state_dict_to_invoke_format(state_dict) + _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict) -def test_lora_model_from_flux_kohya_state_dict(): +@pytest.mark.parametrize("sd_keys", [flux_kohya_state_dict_keys, flux_kohya_te1_state_dict_keys]) +def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]): """Test that a LoRAModelRaw can be created from a state dict in the Kohya FLUX LoRA format.""" # Construct a state dict that is in the Kohya FLUX LoRA format. - state_dict = keys_to_mock_state_dict(flux_kohya_state_dict_keys) + state_dict = keys_to_mock_state_dict(sd_keys) lora_model = lora_model_from_flux_kohya_state_dict(state_dict) # Prepare expected layer keys. expected_layer_keys: set[str] = set() - for k in flux_kohya_state_dict_keys: + for k in sd_keys: + # Remove prefixes. k = k.replace("lora_unet_", "") + k = k.replace("lora_te1_", "") + # Remove suffixes. k = k.replace(".lora_up.weight", "") k = k.replace(".lora_down.weight", "") k = k.replace(".alpha", "") From 249da858df4fb480f510dc1053129008a1cfca9f Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 26 Sep 2024 21:28:25 +0000 Subject: [PATCH 2/7] Add support for FLUX T5 text encoder LoRA models to invocations. --- invokeai/app/invocations/flux_lora_loader.py | 68 ++++++++++++++++---- invokeai/app/invocations/model.py | 3 +- 2 files changed, 57 insertions(+), 14 deletions(-) diff --git a/invokeai/app/invocations/flux_lora_loader.py b/invokeai/app/invocations/flux_lora_loader.py index 46f593ea9f..e82556c74e 100644 --- a/invokeai/app/invocations/flux_lora_loader.py +++ b/invokeai/app/invocations/flux_lora_loader.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation_output, ) from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType -from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, TransformerField +from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, T5EncoderField, TransformerField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import BaseModelType @@ -20,6 +20,9 @@ class FluxLoRALoaderOutput(BaseInvocationOutput): transformer: Optional[TransformerField] = OutputField( default=None, description=FieldDescriptions.transformer, title="FLUX Transformer" ) + t5_encoder: Optional[T5EncoderField] = OutputField( + default=None, description=FieldDescriptions.t5_encoder, title="T5Encoder" + ) @invocation( @@ -27,21 +30,28 @@ class FluxLoRALoaderOutput(BaseInvocationOutput): title="FLUX LoRA", tags=["lora", "model", "flux"], category="model", - version="1.0.0", + version="1.1.0", classification=Classification.Prototype, ) class FluxLoRALoaderInvocation(BaseInvocation): - """Apply a LoRA model to a FLUX transformer.""" + """Apply a LoRA model to a FLUX transformer and/or T5 encoder.""" lora: ModelIdentifierField = InputField( description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel ) weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight) - transformer: TransformerField = InputField( + transformer: TransformerField | None = InputField( + default=None, description=FieldDescriptions.transformer, input=Input.Connection, title="FLUX Transformer", ) + t5_encoder: T5EncoderField | None = InputField( + default=None, + title="T5Encoder", + 
description=FieldDescriptions.t5_encoder, + input=Input.Connection, + ) def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput: lora_key = self.lora.key @@ -49,18 +59,33 @@ class FluxLoRALoaderInvocation(BaseInvocation): if not context.models.exists(lora_key): raise ValueError(f"Unknown lora: {lora_key}!") - if any(lora.lora.key == lora_key for lora in self.transformer.loras): + # Check for existing LoRAs with the same key. + if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras): raise ValueError(f'LoRA "{lora_key}" already applied to transformer.') + if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras): + raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.') - transformer = self.transformer.model_copy(deep=True) - transformer.loras.append( - LoRAField( - lora=self.lora, - weight=self.weight, + output = FluxLoRALoaderOutput() + + # Attach LoRA layers to the models. + if self.transformer is not None: + output.transformer = self.transformer.model_copy(deep=True) + output.transformer.loras.append( + LoRAField( + lora=self.lora, + weight=self.weight, + ) + ) + if self.t5_encoder is not None: + output.t5_encoder = self.t5_encoder.model_copy(deep=True) + output.t5_encoder.loras.append( + LoRAField( + lora=self.lora, + weight=self.weight, + ) ) - ) - return FluxLoRALoaderOutput(transformer=transformer) + return output @invocation( @@ -68,7 +93,7 @@ class FluxLoRALoaderInvocation(BaseInvocation): title="FLUX LoRA Collection Loader", tags=["lora", "model", "flux"], category="model", - version="1.0.0", + version="1.1.0", classification=Classification.Prototype, ) class FLUXLoRACollectionLoader(BaseInvocation): @@ -84,6 +109,18 @@ class FLUXLoRACollectionLoader(BaseInvocation): input=Input.Connection, title="Transformer", ) + transformer: TransformerField | None = InputField( + default=None, + description=FieldDescriptions.transformer, + input=Input.Connection, + title="FLUX Transformer", + ) + t5_encoder: T5EncoderField | None = InputField( + default=None, + title="T5Encoder", + description=FieldDescriptions.t5_encoder, + input=Input.Connection, + ) def invoke(self, context: InvocationContext) -> FluxLoRALoaderOutput: output = FluxLoRALoaderOutput() @@ -106,4 +143,9 @@ class FLUXLoRACollectionLoader(BaseInvocation): output.transformer = self.transformer.model_copy(deep=True) output.transformer.loras.append(lora) + if self.t5_encoder is not None: + if output.t5_encoder is None: + output.t5_encoder = self.t5_encoder.model_copy(deep=True) + output.t5_encoder.loras.append(lora) + return output diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index c0d0a4a7f7..0b87a5cd34 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -75,6 +75,7 @@ class TransformerField(BaseModel): class T5EncoderField(BaseModel): tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") + loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") class VAEField(BaseModel): @@ -205,7 +206,7 @@ class FluxModelLoaderInvocation(BaseInvocation): return FluxModelLoaderOutput( transformer=TransformerField(transformer=transformer, loras=[]), clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0), - t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder), + 
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]), vae=VAEField(vae=vae), max_seq_len=max_seq_lengths[transformer_config.config_path], ) From 7d38a9b7fbec24fff25769638cb265afc2bc8160 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 26 Sep 2024 22:04:54 +0000 Subject: [PATCH 3/7] Add prefix to distinguish FLUX LoRA submodels. --- invokeai/app/invocations/flux_denoise.py | 5 +- invokeai/app/invocations/flux_text_encoder.py | 48 ++++++++++++++++++- .../flux_kohya_lora_conversion_utils.py | 12 +++-- .../test_flux_kohya_lora_conversion_utils.py | 8 ++-- 4 files changed, 63 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index ce5357fcb9..9e24e5ebfd 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -30,6 +30,7 @@ from invokeai.backend.flux.sampling_utils import ( pack, unpack, ) +from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_TRANFORMER_PREFIX from invokeai.backend.lora.lora_model_raw import LoRAModelRaw from invokeai.backend.lora.lora_patcher import LoRAPatcher from invokeai.backend.model_manager.config import ModelFormat @@ -208,7 +209,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): LoRAPatcher.apply_lora_patches( model=transformer, patches=self._lora_iterator(context), - prefix="", + prefix=FLUX_KOHYA_TRANFORMER_PREFIX, cached_weights=cached_weights, ) ) @@ -219,7 +220,7 @@ class FluxDenoiseInvocation(BaseInvocation, WithMetadata, WithBoard): LoRAPatcher.apply_lora_sidecar_patches( model=transformer, patches=self._lora_iterator(context), - prefix="", + prefix=FLUX_KOHYA_TRANFORMER_PREFIX, dtype=inference_dtype, ) ) diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index a19dda30b8..7fa059981f 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -1,4 +1,5 @@ -from typing import Literal +from contextlib import ExitStack +from typing import Iterator, Literal, Tuple import torch from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer @@ -9,6 +10,10 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.conditioner import HFEncoder +from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_T5_PREFIX +from invokeai.backend.lora.lora_model_raw import LoRAModelRaw +from invokeai.backend.lora.lora_patcher import LoRAPatcher +from invokeai.backend.model_manager.config import ModelFormat from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData, FLUXConditioningInfo @@ -58,12 +63,44 @@ class FluxTextEncoderInvocation(BaseInvocation): prompt = [self.prompt] with ( - t5_text_encoder_info as t5_text_encoder, + t5_text_encoder_info.model_on_device() as (cached_weights, t5_text_encoder), t5_tokenizer_info as t5_tokenizer, + ExitStack() as exit_stack, ): assert isinstance(t5_text_encoder, T5EncoderModel) assert isinstance(t5_tokenizer, T5Tokenizer) + t5_text_encoder_config = t5_text_encoder_info.config + assert t5_text_encoder_config is not None + + # Apply LoRA models to the T5 encoder. 
+ # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. + if t5_text_encoder_config.format in [ModelFormat.T5Encoder]: + # The model is non-quantized, so we can apply the LoRA weights directly into the model. + exit_stack.enter_context( + LoRAPatcher.apply_lora_patches( + model=t5_text_encoder, + patches=self._lora_iterator(context), + prefix=FLUX_KOHYA_T5_PREFIX, + cached_weights=cached_weights, + ) + ) + elif t5_text_encoder_config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]: + # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference, + # than directly patching the weights, but is agnostic to the quantization format. + exit_stack.enter_context( + LoRAPatcher.apply_lora_sidecar_patches( + model=t5_text_encoder, + patches=self._lora_iterator(context), + prefix=FLUX_KOHYA_T5_PREFIX, + dtype=t5_text_encoder.dtype, + ) + ) + elif t5_text_encoder_config.format in [ModelFormat.BnbQuantizedLlmInt8b]: + pass + else: + raise ValueError(f"Unsupported model format: {t5_text_encoder_config.format}") + t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len) prompt_embeds = t5_encoder(prompt) @@ -90,3 +127,10 @@ class FluxTextEncoderInvocation(BaseInvocation): assert isinstance(pooled_prompt_embeds, torch.Tensor) return pooled_prompt_embeds + + def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]: + for lora in self.t5_encoder.loras: + lora_info = context.models.load(lora.lora) + assert isinstance(lora_info.model, LoRAModelRaw) + yield (lora_info.model, lora.weight) + del lora_info diff --git a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py index 83e61384b9..ed9b35bb70 100644 --- a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py +++ b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py @@ -1,4 +1,3 @@ -import itertools import re from typing import Any, Dict, TypeVar @@ -24,6 +23,11 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = ( FLUX_KOHYA_T5_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*" +# Prefixes used to distinguish between transformer and T5 keys in the InvokeAI LoRA format. +FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-" +FLUX_KOHYA_T5_PREFIX = "lora_t5-" + + def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool: """Checks if the provided state dict is likely in the Kohya FLUX LoRA format. @@ -61,8 +65,10 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) - # Create LoRA layers. layers: dict[str, AnyLoRALayer] = {} - for layer_key, layer_state_dict in itertools.chain(transformer_grouped_sd.items(), t5_grouped_sd.items()): - layers[layer_key] = any_lora_layer_from_state_dict(layer_state_dict) + for layer_key, layer_state_dict in transformer_grouped_sd.items(): + layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict) + for layer_key, layer_state_dict in t5_grouped_sd.items(): + layers[FLUX_KOHYA_T5_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict) # Create and return the LoRAModelRaw. 
return LoRAModelRaw(layers=layers) diff --git a/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py b/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py index e713b02c02..41ee91c51f 100644 --- a/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py +++ b/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py @@ -5,6 +5,8 @@ import torch from invokeai.backend.flux.model import Flux from invokeai.backend.flux.util import params from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import ( + FLUX_KOHYA_T5_PREFIX, + FLUX_KOHYA_TRANFORMER_PREFIX, _convert_flux_transformer_kohya_state_dict_to_invoke_format, is_state_dict_likely_in_flux_kohya_format, lora_model_from_flux_kohya_state_dict, @@ -92,9 +94,9 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]): # Prepare expected layer keys. expected_layer_keys: set[str] = set() for k in sd_keys: - # Remove prefixes. - k = k.replace("lora_unet_", "") - k = k.replace("lora_te1_", "") + # Replace prefixes. + k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX) + k = k.replace("lora_te1_", FLUX_KOHYA_T5_PREFIX) # Remove suffixes. k = k.replace(".lora_up.weight", "") k = k.replace(".lora_down.weight", "") From c2568260154f0bdb5ab74b656beff5230f53d700 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Sep 2024 14:00:12 +0000 Subject: [PATCH 4/7] Whoops, the 'lora_te1' prefix in FLUX kohya models refers to the CLIP text encoder - not the T5 as previously assumed. Update everything accordingly. --- invokeai/app/invocations/flux_lora_loader.py | 42 +++++------- invokeai/app/invocations/flux_text_encoder.py | 64 ++++++++----------- invokeai/app/invocations/model.py | 3 +- .../flux_kohya_lora_conversion_utils.py | 30 +++++---- .../flux_lora_kohya_with_te1_format.py | 2 +- .../test_flux_kohya_lora_conversion_utils.py | 4 +- 6 files changed, 63 insertions(+), 82 deletions(-) diff --git a/invokeai/app/invocations/flux_lora_loader.py b/invokeai/app/invocations/flux_lora_loader.py index e82556c74e..3cfbb87851 100644 --- a/invokeai/app/invocations/flux_lora_loader.py +++ b/invokeai/app/invocations/flux_lora_loader.py @@ -8,7 +8,7 @@ from invokeai.app.invocations.baseinvocation import ( invocation_output, ) from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType -from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, T5EncoderField, TransformerField +from invokeai.app.invocations.model import CLIPField, LoRAField, ModelIdentifierField, TransformerField from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.model_manager.config import BaseModelType @@ -20,9 +20,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput): transformer: Optional[TransformerField] = OutputField( default=None, description=FieldDescriptions.transformer, title="FLUX Transformer" ) - t5_encoder: Optional[T5EncoderField] = OutputField( - default=None, description=FieldDescriptions.t5_encoder, title="T5Encoder" - ) + clip: Optional[CLIPField] = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") @invocation( @@ -46,10 +44,10 @@ class FluxLoRALoaderInvocation(BaseInvocation): input=Input.Connection, title="FLUX Transformer", ) - t5_encoder: T5EncoderField | None = InputField( + clip: CLIPField | None = InputField( default=None, - title="T5Encoder", - description=FieldDescriptions.t5_encoder, + title="CLIP", + description=FieldDescriptions.clip, 
input=Input.Connection, ) @@ -62,8 +60,8 @@ class FluxLoRALoaderInvocation(BaseInvocation): # Check for existing LoRAs with the same key. if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras): raise ValueError(f'LoRA "{lora_key}" already applied to transformer.') - if self.t5_encoder and any(lora.lora.key == lora_key for lora in self.t5_encoder.loras): - raise ValueError(f'LoRA "{lora_key}" already applied to T5 encoder.') + if self.clip and any(lora.lora.key == lora_key for lora in self.clip.loras): + raise ValueError(f'LoRA "{lora_key}" already applied to CLIP encoder.') output = FluxLoRALoaderOutput() @@ -76,9 +74,9 @@ class FluxLoRALoaderInvocation(BaseInvocation): weight=self.weight, ) ) - if self.t5_encoder is not None: - output.t5_encoder = self.t5_encoder.model_copy(deep=True) - output.t5_encoder.loras.append( + if self.clip is not None: + output.clip = self.clip.model_copy(deep=True) + output.clip.loras.append( LoRAField( lora=self.lora, weight=self.weight, @@ -109,16 +107,10 @@ class FLUXLoRACollectionLoader(BaseInvocation): input=Input.Connection, title="Transformer", ) - transformer: TransformerField | None = InputField( + clip: CLIPField | None = InputField( default=None, - description=FieldDescriptions.transformer, - input=Input.Connection, - title="FLUX Transformer", - ) - t5_encoder: T5EncoderField | None = InputField( - default=None, - title="T5Encoder", - description=FieldDescriptions.t5_encoder, + title="CLIP", + description=FieldDescriptions.clip, input=Input.Connection, ) @@ -143,9 +135,9 @@ class FLUXLoRACollectionLoader(BaseInvocation): output.transformer = self.transformer.model_copy(deep=True) output.transformer.loras.append(lora) - if self.t5_encoder is not None: - if output.t5_encoder is None: - output.t5_encoder = self.t5_encoder.model_copy(deep=True) - output.t5_encoder.loras.append(lora) + if self.clip is not None: + if output.clip is None: + output.clip = self.clip.model_copy(deep=True) + output.clip.loras.append(lora) return output diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index 7fa059981f..ac70273317 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -10,7 +10,7 @@ from invokeai.app.invocations.model import CLIPField, T5EncoderField from invokeai.app.invocations.primitives import FluxConditioningOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.flux.modules.conditioner import HFEncoder -from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_T5_PREFIX +from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import FLUX_KOHYA_CLIP_PREFIX from invokeai.backend.lora.lora_model_raw import LoRAModelRaw from invokeai.backend.lora.lora_patcher import LoRAPatcher from invokeai.backend.model_manager.config import ModelFormat @@ -22,7 +22,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import Condit title="FLUX Text Encoding", tags=["prompt", "conditioning", "flux"], category="conditioning", - version="1.0.0", + version="1.1.0", classification=Classification.Prototype, ) class FluxTextEncoderInvocation(BaseInvocation): @@ -63,44 +63,12 @@ class FluxTextEncoderInvocation(BaseInvocation): prompt = [self.prompt] with ( - t5_text_encoder_info.model_on_device() as (cached_weights, t5_text_encoder), + t5_text_encoder_info as t5_text_encoder, t5_tokenizer_info as t5_tokenizer, - 
ExitStack() as exit_stack, ): assert isinstance(t5_text_encoder, T5EncoderModel) assert isinstance(t5_tokenizer, T5Tokenizer) - t5_text_encoder_config = t5_text_encoder_info.config - assert t5_text_encoder_config is not None - - # Apply LoRA models to the T5 encoder. - # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. - if t5_text_encoder_config.format in [ModelFormat.T5Encoder]: - # The model is non-quantized, so we can apply the LoRA weights directly into the model. - exit_stack.enter_context( - LoRAPatcher.apply_lora_patches( - model=t5_text_encoder, - patches=self._lora_iterator(context), - prefix=FLUX_KOHYA_T5_PREFIX, - cached_weights=cached_weights, - ) - ) - elif t5_text_encoder_config.format in [ModelFormat.BnbQuantizedLlmInt8b, ModelFormat.BnbQuantizednf4b]: - # The model is quantized, so apply the LoRA weights as sidecar layers. This results in slower inference, - # than directly patching the weights, but is agnostic to the quantization format. - exit_stack.enter_context( - LoRAPatcher.apply_lora_sidecar_patches( - model=t5_text_encoder, - patches=self._lora_iterator(context), - prefix=FLUX_KOHYA_T5_PREFIX, - dtype=t5_text_encoder.dtype, - ) - ) - elif t5_text_encoder_config.format in [ModelFormat.BnbQuantizedLlmInt8b]: - pass - else: - raise ValueError(f"Unsupported model format: {t5_text_encoder_config.format}") - t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len) prompt_embeds = t5_encoder(prompt) @@ -115,12 +83,32 @@ class FluxTextEncoderInvocation(BaseInvocation): prompt = [self.prompt] with ( - clip_text_encoder_info as clip_text_encoder, + clip_text_encoder_info.model_on_device() as (cached_weights, clip_text_encoder), clip_tokenizer_info as clip_tokenizer, + ExitStack() as exit_stack, ): assert isinstance(clip_text_encoder, CLIPTextModel) assert isinstance(clip_tokenizer, CLIPTokenizer) + clip_text_encoder_config = clip_text_encoder_info.config + assert clip_text_encoder_config is not None + + # Apply LoRA models to the T5 encoder. + # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. + if clip_text_encoder_config.format in [ModelFormat.Diffusers]: + # The model is non-quantized, so we can apply the LoRA weights directly into the model. + exit_stack.enter_context( + LoRAPatcher.apply_lora_patches( + model=clip_text_encoder, + patches=self._clip_lora_iterator(context), + prefix=FLUX_KOHYA_CLIP_PREFIX, + cached_weights=cached_weights, + ) + ) + else: + # There are currently no supported CLIP quantized models. Add support here if needed. 
+ raise ValueError(f"Unsupported model format: {clip_text_encoder_config.format}") + clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77) pooled_prompt_embeds = clip_encoder(prompt) @@ -128,8 +116,8 @@ class FluxTextEncoderInvocation(BaseInvocation): assert isinstance(pooled_prompt_embeds, torch.Tensor) return pooled_prompt_embeds - def _lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]: - for lora in self.t5_encoder.loras: + def _clip_lora_iterator(self, context: InvocationContext) -> Iterator[Tuple[LoRAModelRaw, float]]: + for lora in self.clip.loras: lora_info = context.models.load(lora.lora) assert isinstance(lora_info.model, LoRAModelRaw) yield (lora_info.model, lora.weight) diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 0b87a5cd34..c0d0a4a7f7 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -75,7 +75,6 @@ class TransformerField(BaseModel): class T5EncoderField(BaseModel): tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") - loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") class VAEField(BaseModel): @@ -206,7 +205,7 @@ class FluxModelLoaderInvocation(BaseInvocation): return FluxModelLoaderOutput( transformer=TransformerField(transformer=transformer, loras=[]), clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0), - t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder, loras=[]), + t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder), vae=VAEField(vae=vae), max_seq_len=max_seq_lengths[transformer_config.config_path], ) diff --git a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py index ed9b35bb70..94df3d5567 100644 --- a/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py +++ b/invokeai/backend/lora/conversions/flux_kohya_lora_conversion_utils.py @@ -15,17 +15,17 @@ from invokeai.backend.lora.lora_model_raw import LoRAModelRaw FLUX_KOHYA_TRANSFORMER_KEY_REGEX = ( r"lora_unet_(\w+_blocks)_(\d+)_(img_attn|img_mlp|img_mod|txt_attn|txt_mlp|txt_mod|linear1|linear2|modulation)_?(.*)" ) -# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format. +# A regex pattern that matches all of the CLIP keys in the Kohya FLUX LoRA format. # Example keys: # lora_te1_text_model_encoder_layers_0_mlp_fc1.alpha # lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight # lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight -FLUX_KOHYA_T5_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*" +FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*" -# Prefixes used to distinguish between transformer and T5 keys in the InvokeAI LoRA format. +# Prefixes used to distinguish between transformer and CLIP text encoder keys in the InvokeAI LoRA format. 
FLUX_KOHYA_TRANFORMER_PREFIX = "lora_transformer-" -FLUX_KOHYA_T5_PREFIX = "lora_t5-" +FLUX_KOHYA_CLIP_PREFIX = "lora_clip-" def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool: @@ -35,7 +35,8 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) """ return all( - re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys() + re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) + for k in state_dict.keys() ) @@ -48,27 +49,27 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) - grouped_state_dict[layer_name] = {} grouped_state_dict[layer_name][param_name] = value - # Split the grouped state dict into transformer and T5 state dicts. + # Split the grouped state dict into transformer and CLIP state dicts. transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {} - t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {} + clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {} for layer_name, layer_state_dict in grouped_state_dict.items(): if layer_name.startswith("lora_unet"): transformer_grouped_sd[layer_name] = layer_state_dict elif layer_name.startswith("lora_te1"): - t5_grouped_sd[layer_name] = layer_state_dict + clip_grouped_sd[layer_name] = layer_state_dict else: raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.") # Convert the state dicts to the InvokeAI format. transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd) - t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd) + clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd) # Create LoRA layers. layers: dict[str, AnyLoRALayer] = {} for layer_key, layer_state_dict in transformer_grouped_sd.items(): layers[FLUX_KOHYA_TRANFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict) - for layer_key, layer_state_dict in t5_grouped_sd.items(): - layers[FLUX_KOHYA_T5_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict) + for layer_key, layer_state_dict in clip_grouped_sd.items(): + layers[FLUX_KOHYA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict) # Create and return the LoRAModelRaw. return LoRAModelRaw(layers=layers) @@ -77,8 +78,9 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) - T = TypeVar("T") -def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]: - """Converts a T5 LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by InvokeAI. +def _convert_flux_clip_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]: + """Converts a CLIP LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by + InvokeAI. 
Example key conversions: @@ -87,7 +89,7 @@ def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) """ converted_sd: dict[str, T] = {} for k, v in state_dict.items(): - match = re.match(FLUX_KOHYA_T5_KEY_REGEX, k) + match = re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) if match: new_key = f"text_model.encoder.layers.{match.group(1)}.{match.group(2)}.{match.group(3)}" converted_sd[new_key] = v diff --git a/tests/backend/lora/conversions/lora_state_dicts/flux_lora_kohya_with_te1_format.py b/tests/backend/lora/conversions/lora_state_dicts/flux_lora_kohya_with_te1_format.py index c43505e9c0..f7689936fa 100644 --- a/tests/backend/lora/conversions/lora_state_dicts/flux_lora_kohya_with_te1_format.py +++ b/tests/backend/lora/conversions/lora_state_dicts/flux_lora_kohya_with_te1_format.py @@ -1,4 +1,4 @@ -# A sample state dict in the Kohya FLUX LoRA format that patches both the transformer and T5 text encoder. +# A sample state dict in the Kohya FLUX LoRA format that patches both the transformer and CLIP text encoder. # These keys are based on the LoRA model here: # https://huggingface.co/cocktailpeanut/optimus state_dict_keys = [ diff --git a/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py b/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py index 41ee91c51f..7878137130 100644 --- a/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py +++ b/tests/backend/lora/conversions/test_flux_kohya_lora_conversion_utils.py @@ -5,7 +5,7 @@ import torch from invokeai.backend.flux.model import Flux from invokeai.backend.flux.util import params from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import ( - FLUX_KOHYA_T5_PREFIX, + FLUX_KOHYA_CLIP_PREFIX, FLUX_KOHYA_TRANFORMER_PREFIX, _convert_flux_transformer_kohya_state_dict_to_invoke_format, is_state_dict_likely_in_flux_kohya_format, @@ -96,7 +96,7 @@ def test_lora_model_from_flux_kohya_state_dict(sd_keys: list[str]): for k in sd_keys: # Replace prefixes. k = k.replace("lora_unet_", FLUX_KOHYA_TRANFORMER_PREFIX) - k = k.replace("lora_te1_", FLUX_KOHYA_T5_PREFIX) + k = k.replace("lora_te1_", FLUX_KOHYA_CLIP_PREFIX) # Remove suffixes. k = k.replace(".lora_up.weight", "") k = k.replace(".lora_down.weight", "") From 3463a968c7d9305a02c2f3d6010720da2a7c0f42 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Sep 2024 14:27:41 +0000 Subject: [PATCH 5/7] Update Linear UI to support FLUX LoRA models that patch the CLIP model in addition to the transformer. 
--- .../util/graph/generation/addFLUXLoRAs.ts | 11 ++++++---- .../util/graph/generation/buildFLUXGraph.ts | 6 ++--- .../frontend/web/src/services/api/schema.ts | 22 +++++++++++++++++-- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts index e9e755451d..cf77adf234 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts @@ -8,7 +8,8 @@ export const addFLUXLoRAs = ( state: RootState, g: Graph, denoise: Invocation<'flux_denoise'>, - modelLoader: Invocation<'flux_model_loader'> + modelLoader: Invocation<'flux_model_loader'>, + fluxTextEncoder: Invocation<'flux_text_encoder'>, ): void => { const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux'); const loraCount = enabledLoRAs.length; @@ -20,7 +21,7 @@ export const addFLUXLoRAs = ( const loraMetadata: S['LoRAMetadataField'][] = []; // We will collect LoRAs into a single collection node, then pass them to the LoRA collection loader, which applies - // each LoRA to the UNet and CLIP. + // each LoRA to the transformer and text encoders. const loraCollector = g.addNode({ id: getPrefixedId('lora_collector'), type: 'collect', @@ -33,10 +34,12 @@ export const addFLUXLoRAs = ( g.addEdge(loraCollector, 'collection', loraCollectionLoader, 'loras'); // Use model loader as transformer input g.addEdge(modelLoader, 'transformer', loraCollectionLoader, 'transformer'); - // Reroute transformer connections through the LoRA collection loader + g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip'); + // Reroute model connections through the LoRA collection loader g.deleteEdgesTo(denoise, ['transformer']); - + g.deleteEdgesTo(fluxTextEncoder, ['clip']) g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer'); + g.addEdge(loraCollectionLoader, 'clip', fluxTextEncoder, 'clip'); for (const lora of enabledLoRAs) { const { weight } = lora; diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 7d48480c1d..488b19f5bf 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -6,6 +6,7 @@ import { selectCanvasSettingsSlice } from 'features/controlLayers/store/canvasSe import { selectParamsSlice } from 'features/controlLayers/store/paramsSlice'; import { selectCanvasMetadata, selectCanvasSlice } from 'features/controlLayers/store/selectors'; import { fetchModelConfigWithTypeGuard } from 'features/metadata/util/modelFetchingHelpers'; +import { addFLUXLoRAs } from 'features/nodes/util/graph/generation/addFLUXLoRAs'; import { addImageToImage } from 'features/nodes/util/graph/generation/addImageToImage'; import { addInpaint } from 'features/nodes/util/graph/generation/addInpaint'; import { addNSFWChecker } from 'features/nodes/util/graph/generation/addNSFWChecker'; @@ -18,7 +19,6 @@ import type { Invocation } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; -import { addFLUXLoRAs } from './addFLUXLoRAs'; const log = logger('system'); @@ -96,12 +96,12 @@ export const buildFLUXGraph = async ( 
g.addEdge(modelLoader, 'transformer', noise, 'transformer'); g.addEdge(modelLoader, 'vae', l2i, 'vae'); - addFLUXLoRAs(state, g, noise, modelLoader); - g.addEdge(modelLoader, 'clip', posCond, 'clip'); g.addEdge(modelLoader, 't5_encoder', posCond, 't5_encoder'); g.addEdge(modelLoader, 'max_seq_len', posCond, 't5_max_seq_len'); + addFLUXLoRAs(state, g, noise, modelLoader, posCond); + g.addEdge(posCond, 'conditioning', noise, 'positive_text_conditioning'); g.addEdge(noise, 'latents', l2i, 'latents'); diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index b525d93de4..c53c3d56b7 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -5707,6 +5707,12 @@ export type components = { * @default null */ transformer?: components["schemas"]["TransformerField"] | null; + /** + * CLIP + * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null + */ + clip?: components["schemas"]["CLIPField"] | null; /** * type * @default flux_lora_collection_loader @@ -6391,7 +6397,7 @@ export type components = { }; /** * FLUX LoRA - * @description Apply a LoRA model to a FLUX transformer. + * @description Apply a LoRA model to a FLUX transformer and/or T5 encoder. */ FluxLoRALoaderInvocation: { /** @@ -6428,7 +6434,13 @@ export type components = { * @description Transformer * @default null */ - transformer?: components["schemas"]["TransformerField"]; + transformer?: components["schemas"]["TransformerField"] | null; + /** + * CLIP + * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null + */ + clip?: components["schemas"]["CLIPField"] | null; /** * type * @default flux_lora_loader @@ -6448,6 +6460,12 @@ export type components = { * @default null */ transformer: components["schemas"]["TransformerField"] | null; + /** + * CLIP + * @description CLIP (tokenizer, text encoder, LoRAs) and skipped layer count + * @default null + */ + clip: components["schemas"]["CLIPField"] | null; /** * type * @default flux_lora_loader_output From ba8ef6ff0f4bc828d371f9f96e7b29c540de4aad Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Sep 2024 14:46:25 +0000 Subject: [PATCH 6/7] (minor) remove remaining incorrect references to T5 encoder in comments. 
--- invokeai/app/invocations/flux_lora_loader.py | 2 +- invokeai/app/invocations/flux_text_encoder.py | 2 +- invokeai/frontend/web/src/services/api/schema.ts | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/invokeai/app/invocations/flux_lora_loader.py b/invokeai/app/invocations/flux_lora_loader.py index 3cfbb87851..d9e655a507 100644 --- a/invokeai/app/invocations/flux_lora_loader.py +++ b/invokeai/app/invocations/flux_lora_loader.py @@ -32,7 +32,7 @@ class FluxLoRALoaderOutput(BaseInvocationOutput): classification=Classification.Prototype, ) class FluxLoRALoaderInvocation(BaseInvocation): - """Apply a LoRA model to a FLUX transformer and/or T5 encoder.""" + """Apply a LoRA model to a FLUX transformer and/or text encoder.""" lora: ModelIdentifierField = InputField( description=FieldDescriptions.lora_model, title="LoRA", ui_type=UIType.LoRAModel diff --git a/invokeai/app/invocations/flux_text_encoder.py b/invokeai/app/invocations/flux_text_encoder.py index ac70273317..a306a8aa95 100644 --- a/invokeai/app/invocations/flux_text_encoder.py +++ b/invokeai/app/invocations/flux_text_encoder.py @@ -93,7 +93,7 @@ class FluxTextEncoderInvocation(BaseInvocation): clip_text_encoder_config = clip_text_encoder_info.config assert clip_text_encoder_config is not None - # Apply LoRA models to the T5 encoder. + # Apply LoRA models to the CLIP encoder. # Note: We apply the LoRA after the transformer has been moved to its target device for faster patching. if clip_text_encoder_config.format in [ModelFormat.Diffusers]: # The model is non-quantized, so we can apply the LoRA weights directly into the model. diff --git a/invokeai/frontend/web/src/services/api/schema.ts b/invokeai/frontend/web/src/services/api/schema.ts index c53c3d56b7..777e7bfd30 100644 --- a/invokeai/frontend/web/src/services/api/schema.ts +++ b/invokeai/frontend/web/src/services/api/schema.ts @@ -6397,7 +6397,7 @@ export type components = { }; /** * FLUX LoRA - * @description Apply a LoRA model to a FLUX transformer and/or T5 encoder. + * @description Apply a LoRA model to a FLUX transformer and/or text encoder. */ FluxLoRALoaderInvocation: { /** From a424552c82eb83e367979fdc5f6f2029049c30ae Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Fri, 27 Sep 2024 15:04:48 +0000 Subject: [PATCH 7/7] Fix frontend lint errors. 
--- .../src/features/nodes/util/graph/generation/addFLUXLoRAs.ts | 4 ++-- .../features/nodes/util/graph/generation/buildFLUXGraph.ts | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts index cf77adf234..a57e655c73 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/addFLUXLoRAs.ts @@ -9,7 +9,7 @@ export const addFLUXLoRAs = ( g: Graph, denoise: Invocation<'flux_denoise'>, modelLoader: Invocation<'flux_model_loader'>, - fluxTextEncoder: Invocation<'flux_text_encoder'>, + fluxTextEncoder: Invocation<'flux_text_encoder'> ): void => { const enabledLoRAs = state.loras.loras.filter((l) => l.isEnabled && l.model.base === 'flux'); const loraCount = enabledLoRAs.length; @@ -37,7 +37,7 @@ export const addFLUXLoRAs = ( g.addEdge(modelLoader, 'clip', loraCollectionLoader, 'clip'); // Reroute model connections through the LoRA collection loader g.deleteEdgesTo(denoise, ['transformer']); - g.deleteEdgesTo(fluxTextEncoder, ['clip']) + g.deleteEdgesTo(fluxTextEncoder, ['clip']); g.addEdge(loraCollectionLoader, 'transformer', denoise, 'transformer'); g.addEdge(loraCollectionLoader, 'clip', fluxTextEncoder, 'clip'); diff --git a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts index 488b19f5bf..50e55526b0 100644 --- a/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts +++ b/invokeai/frontend/web/src/features/nodes/util/graph/generation/buildFLUXGraph.ts @@ -19,7 +19,6 @@ import type { Invocation } from 'services/api/types'; import { isNonRefinerMainModelConfig } from 'services/api/types'; import { assert } from 'tsafe'; - const log = logger('system'); export const buildFLUXGraph = async (
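For readers following the patch series, here is a minimal sketch (not part of the patches themselves) of how the prefix constants introduced in flux_kohya_lora_conversion_utils.py are intended to be consumed: a single converted LoRAModelRaw carries both transformer and CLIP layers, and each patching site selects its subset by key prefix. The summarize_flux_lora_targets helper and the example layer keys below are illustrative assumptions, not InvokeAI APIs.

# Illustrative sketch only. It relies on the two prefix constants added in this
# patch series; the helper and the example keys are hypothetical.
from invokeai.backend.lora.conversions.flux_kohya_lora_conversion_utils import (
    FLUX_KOHYA_CLIP_PREFIX,
    FLUX_KOHYA_TRANFORMER_PREFIX,
)


def summarize_flux_lora_targets(layer_keys: list[str]) -> dict[str, int]:
    """Count how many converted LoRA layer keys target the FLUX transformer vs. the CLIP text encoder."""
    return {
        "transformer": sum(k.startswith(FLUX_KOHYA_TRANFORMER_PREFIX) for k in layer_keys),
        "clip": sum(k.startswith(FLUX_KOHYA_CLIP_PREFIX) for k in layer_keys),
    }


# Example usage with keys shaped like the converted output (key bodies are assumptions):
print(summarize_flux_lora_targets([
    FLUX_KOHYA_TRANFORMER_PREFIX + "double_blocks_0_img_attn_proj",
    FLUX_KOHYA_CLIP_PREFIX + "text_model.encoder.layers.0.mlp.fc1",
]))  # -> {'transformer': 1, 'clip': 1}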