Add support for LyCoris-style LoRA keys in lora_model_from_flux_diffusers_state_dict(). Previously, it only supported PEFT-style LoRA keys.
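For reference, the two key styles differ only in how a single LoRA layer's tensors are named. A rough illustration (toy rank-16 shapes and bare key names only; real FLUX diffusers state dicts carry per-layer prefixes in front of these suffixes):

import torch

# Illustrative only: the same logical LoRA layer expressed in the two key styles
# this conversion now accepts.
peft_style = {
    "lora_A.weight": torch.zeros(16, 3072),  # down-projection
    "lora_B.weight": torch.zeros(3072, 16),  # up-projection
}
kohya_style = {
    "lora_down.weight": torch.zeros(16, 3072),
    "lora_up.weight": torch.zeros(3072, 16),
    "alpha": torch.tensor(16.0),
}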

Author: Ryan Dick
Date: 2025-01-22 17:20:28 +00:00
parent dfa253e75b
commit 908976ac08


@@ -4,7 +4,7 @@ import torch
 from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
-from invokeai.backend.patches.layers.lora_layer import LoRALayer
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -53,16 +53,25 @@ def lora_model_from_flux_diffusers_state_dict(
     layers: dict[str, BaseLayerPatch] = {}
 
-    def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
-        if src_key in grouped_state_dict:
-            src_layer_dict = grouped_state_dict.pop(src_key)
-            value = {
+    def get_lora_layer_values(src_layer_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
+        if "lora_A.weight" in src_layer_dict:
+            # The LoRA keys are in PEFT format.
+            values = {
                 "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
                 "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
             }
             if alpha is not None:
-                value["alpha"] = torch.tensor(alpha)
-            layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
+                values["alpha"] = torch.tensor(alpha)
+            return values
+        else:
+            # Assume that the LoRA keys are in Kohya format.
+            return src_layer_dict
+
+    def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
+        if src_key in grouped_state_dict:
+            src_layer_dict = grouped_state_dict.pop(src_key)
+            values = get_lora_layer_values(src_layer_dict)
+            layers[dst_key] = any_lora_layer_from_state_dict(values)
             assert len(src_layer_dict) == 0
 
     def add_qkv_lora_layer_if_present(
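The new get_lora_layer_values() helper is what makes both key styles work: it renames PEFT-style keys to the Kohya-style names that any_lora_layer_from_state_dict() expects, and passes Kohya/LyCoris-style dicts through untouched. A minimal standalone sketch of that dispatch (toy code, not the InvokeAI implementation itself):

import torch

def normalize_lora_layer(src: dict[str, torch.Tensor], alpha: float | None) -> dict[str, torch.Tensor]:
    # PEFT-style keys: rename lora_A/lora_B to the Kohya-style lora_down/lora_up names.
    if "lora_A.weight" in src:
        values = {
            "lora_down.weight": src.pop("lora_A.weight"),
            "lora_up.weight": src.pop("lora_B.weight"),
        }
        if alpha is not None:
            values["alpha"] = torch.tensor(alpha)
        return values
    # Otherwise assume the keys are already Kohya/LyCoris-style and pass them through.
    return src

layer = {"lora_A.weight": torch.zeros(16, 3072), "lora_B.weight": torch.zeros(3072, 16)}
print(sorted(normalize_lora_layer(layer, alpha=16.0).keys()))
# ['alpha', 'lora_down.weight', 'lora_up.weight']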
@@ -79,19 +88,14 @@ def lora_model_from_flux_diffusers_state_dict(
         if not any(keys_present):
             return
 
-        sub_layers: list[LoRALayer] = []
+        sub_layers: list[BaseLayerPatch] = []
         for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
             src_layer_dict = grouped_state_dict.pop(src_key, None)
             if src_layer_dict is not None:
-                values = {
-                    "lora_down.weight": src_layer_dict.pop("lora_A.weight"),
-                    "lora_up.weight": src_layer_dict.pop("lora_B.weight"),
-                }
-                if alpha is not None:
-                    values["alpha"] = torch.tensor(alpha)
+                values = get_lora_layer_values(src_layer_dict)
                 assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
                 assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
-                sub_layers.append(LoRALayer.from_state_dict_values(values=values))
+                sub_layers.append(any_lora_layer_from_state_dict(values))
                 assert len(src_layer_dict) == 0
             else:
                 if not allow_missing_keys:
@@ -100,7 +104,7 @@ def lora_model_from_flux_diffusers_state_dict(
                     "lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
                     "lora_down.weight": torch.zeros((1, src_weight_shape[1])),
                 }
-                sub_layers.append(LoRALayer.from_state_dict_values(values=values))
+                sub_layers.append(any_lora_layer_from_state_dict(values))
         layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
 
     # time_text_embed.timestep_embedder -> time_in.
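One behavior worth noting from the QKV path above: when a Q/K/V source key is absent and allow_missing_keys is set, the loop substitutes a rank-1, all-zero LoRA with matching outer dimensions, so the ConcatenatedLoRALayer still lines up with the fused QKV weight while contributing nothing. A small sketch of that placeholder (toy code, shapes assumed from the diff):

import torch

def zero_placeholder(src_weight_shape: tuple[int, int]) -> dict[str, torch.Tensor]:
    # Rank-1, all-zero LoRA: (out, 1) @ (1, in) has the target weight's shape but adds nothing.
    return {
        "lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
        "lora_down.weight": torch.zeros((1, src_weight_shape[1])),
    }

values = zero_placeholder((3072, 3072))
delta = values["lora_up.weight"] @ values["lora_down.weight"]
assert delta.shape == (3072, 3072) and delta.abs().sum() == 0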