mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-01-14 16:47:55 -05:00
Add support for LyCoris-style LoRA keys in lora_model_from_flux_diffusers_state_dict(). Previously, it only supported PEFT-style LoRA keys.
This commit is contained in:
@@ -4,7 +4,7 @@ import torch
|
||||
|
||||
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
|
||||
from invokeai.backend.patches.layers.concatenated_lora_layer import ConcatenatedLoRALayer
|
||||
from invokeai.backend.patches.layers.lora_layer import LoRALayer
|
||||
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
|
||||
from invokeai.backend.patches.lora_conversions.flux_lora_constants import FLUX_LORA_TRANSFORMER_PREFIX
|
||||
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
|
||||
|
||||
@@ -53,16 +53,25 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
|
||||
layers: dict[str, BaseLayerPatch] = {}
|
||||
|
||||
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
|
||||
if src_key in grouped_state_dict:
|
||||
src_layer_dict = grouped_state_dict.pop(src_key)
|
||||
value = {
|
||||
def get_lora_layer_values(src_layer_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
||||
if "lora_A.weight" in src_layer_dict:
|
||||
# The LoRA keys are in PEFT format.
|
||||
values = {
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
}
|
||||
if alpha is not None:
|
||||
value["alpha"] = torch.tensor(alpha)
|
||||
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
|
||||
values["alpha"] = torch.tensor(alpha)
|
||||
return values
|
||||
else:
|
||||
# Assume that the LoRA keys are in Kohya format.
|
||||
return src_layer_dict
|
||||
|
||||
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
|
||||
if src_key in grouped_state_dict:
|
||||
src_layer_dict = grouped_state_dict.pop(src_key)
|
||||
values = get_lora_layer_values(src_layer_dict)
|
||||
layers[dst_key] = any_lora_layer_from_state_dict(values)
|
||||
assert len(src_layer_dict) == 0
|
||||
|
||||
def add_qkv_lora_layer_if_present(
|
||||
@@ -79,19 +88,14 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
if not any(keys_present):
|
||||
return
|
||||
|
||||
sub_layers: list[LoRALayer] = []
|
||||
sub_layers: list[BaseLayerPatch] = []
|
||||
for src_key, src_weight_shape in zip(src_keys, src_weight_shapes, strict=True):
|
||||
src_layer_dict = grouped_state_dict.pop(src_key, None)
|
||||
if src_layer_dict is not None:
|
||||
values = {
|
||||
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
|
||||
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
|
||||
}
|
||||
if alpha is not None:
|
||||
values["alpha"] = torch.tensor(alpha)
|
||||
values = get_lora_layer_values(src_layer_dict)
|
||||
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
|
||||
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
|
||||
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
|
||||
sub_layers.append(any_lora_layer_from_state_dict(values))
|
||||
assert len(src_layer_dict) == 0
|
||||
else:
|
||||
if not allow_missing_keys:
|
||||
@@ -100,7 +104,7 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
"lora_up.weight": torch.zeros((src_weight_shape[0], 1)),
|
||||
"lora_down.weight": torch.zeros((1, src_weight_shape[1])),
|
||||
}
|
||||
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
|
||||
sub_layers.append(any_lora_layer_from_state_dict(values))
|
||||
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers)
|
||||
|
||||
# time_text_embed.timestep_embedder -> time_in.
|
||||
|
||||
Reference in New Issue
Block a user