diff --git a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
index b4fa481468..8c64287c84 100644
--- a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py
@@ -4,7 +4,7 @@ import torch
 
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 
@@ -53,16 +53,25 @@ def lora_model_from_flux_diffusers_state_dict(
 
     layers: dict[str, BaseLayerPatch] = {}
 
-    def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
-        if src_key in grouped_state_dict:
-            src_layer_dict = grouped_state_dict.pop(src_key)
-            value = {
+    def get_lora_layer_values(src_layer_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        if "lora_A.weight" in src_layer_dict:
+            # The LoRA keys are in PEFT format.
+            values = {
                 "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                 "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
             }
             if alpha is not None:
-                value["alpha"] = torch.tensor(alpha)
-            layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
+                values["alpha"] = torch.tensor(alpha)
+            return values
+        else:
+            # Assume that the LoRA keys are in Kohya format.
+            return src_layer_dict
+
+    def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
+        if src_key in grouped_state_dict:
+            src_layer_dict = grouped_state_dict.pop(src_key)
+            values = get_lora_layer_values(src_layer_dict)
+            layers[dst_key] = any_lora_layer_from_state_dict(values)
             assert len(src_layer_dict) == 0
 
     def add_qkv_lora_layer_if_present(
@@ -79,19 +88,14 @@ def lora_model_from_flux_diffusers_state_dict(
         if not any(keys_present):
             return
 
-        sub_layers: list[LoRALayer] = []
+        sub_layers: list[BaseLayerPatch] = []
         for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
             src_layer_dict = grouped_state_dict.pop(src_key, None)
             if src_layer_dict is not None:
-                values = {
-                    "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
-                    "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
-                }
-                if alpha is not None:
-                    values["alpha"] = torch.tensor(alpha)
+                values = get_lora_layer_values(src_layer_dict)
                 assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
                 assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
-                sub_layers.append(LoRALayer.from_state_dict_values(values=values))
+                sub_layers.append(any_lora_layer_from_state_dict(values))
                 assert len(src_layer_dict) == 0
             else:
                 if not allow_missing_keys:
@@ -100,7 +104,7 @@ def lora_model_from_flux_diffusers_state_dict(
                     "lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
                     "lora_down.weight": torch.zeros((1, src_weight_shape[1])),
                 }
-                sub_layers.append(LoRALayer.from_state_dict_values(values=values))
+                sub_layers.append(any_lora_layer_from_state_dict(values))
         layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
 
     # time_text_embed.timestep_embedder -> time_in.
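
The core of the change above is that each per-layer state dict is normalized before being handed to the layer factory: PEFT-style keys (`lora_A.weight`/`lora_B.weight`) are remapped to the Kohya-style keys (`lora_down.weight`/`lora_up.weight`), while dicts that already use Kohya keys pass through unchanged, and `any_lora_layer_from_state_dict` then selects the appropriate layer class instead of hard-coding `LoRALayer`. A minimal standalone sketch of that normalization step (not the repository code; `normalize_lora_keys` and the toy shapes below are illustrative only):

```python
import torch


def normalize_lora_keys(
    layer_sd: dict[str, torch.Tensor], alpha: float | None = None
) -> dict[str, torch.Tensor]:
    """Mirror the get_lora_layer_values() logic: map PEFT-style keys
    (lora_A/lora_B) to Kohya-style keys (lora_down/lora_up); pass
    Kohya-style dicts through unchanged."""
    if "lora_A.weight" in layer_sd:
        values = {
            "lora_down.weight": layer_sd.pop("lora_A.weight"),
            "lora_up.weight": layer_sd.pop("lora_B.weight"),
        }
        if alpha is not None:
            values["alpha"] = torch.tensor(alpha)
        return values
    # Assume the keys are already in Kohya format.
    return layer_sd


# Toy example: a rank-4 LoRA patch for a 32x64 weight, stored in PEFT format.
peft_layer = {
    "lora_A.weight": torch.zeros(4, 64),  # down-projection: (rank, in_features)
    "lora_B.weight": torch.zeros(32, 4),  # up-projection: (out_features, rank)
}
kohya_layer = normalize_lora_keys(peft_layer, alpha=8.0)
assert set(kohya_layer) == {"lora_down.weight", "lora_up.weight", "alpha"}
```

Because the QKV path also routes through the same helper, the shape asserts in `add_qkv_lora_layer_if_present` continue to check the down/up matrices against the expected source weight shapes regardless of which key format the incoming LoRA used.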