Assume alpha=rank for FLUX diffusers PEFT LoRA models.

This commit is contained in:
Ryan Dick
2024-09-16 13:57:07 +00:00
parent d51f2c5e00
commit e88d3cf2f7
2 changed files with 21 additions and 25 deletions

View File

@@ -5,7 +5,6 @@ import torch
from invokeai.backend.lora.layers.any_lora_layer import AnyLoRALayer
from invokeai.backend.lora.layers.concatenated_lora_layer import ConcatenatedLoRALayer
from invokeai.backend.lora.layers.lora_layer import LoRALayer
from invokeai.backend.lora.layers.lora_layer_base import LoRALayerBase
from invokeai.backend.lora.lora_model_raw import LoRAModelRaw
@@ -30,7 +29,7 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
return all_keys_in_peft_format and all_expected_keys_present
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw:
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float | None) -> LoRAModelRaw:
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:
@@ -53,13 +52,13 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
def add_lora_layer_if_present(src_key: str, dst_key: str) -> None:
if src_key in grouped_state_dict:
src_layer_dict = grouped_state_dict.pop(src_key)
layers[dst_key] = LoRALayer.from_state_dict_values(
values={
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
"alpha": torch.tensor(alpha),
},
)
value = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
value["alpha"] = torch.tensor(alpha)
layers[dst_key] = LoRALayer.from_state_dict_values(values=value)
assert len(src_layer_dict) == 0
def add_qkv_lora_layer_if_present(src_keys: list[str], dst_qkv_key: str) -> None:
@@ -75,17 +74,15 @@ def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor
return
src_layer_dicts = [grouped_state_dict.pop(key) for key in src_keys]
sub_layers: list[LoRALayerBase] = []
sub_layers: list[LoRALayer] = []
for src_layer_dict in src_layer_dicts:
sub_layers.append(
LoRALayer.from_state_dict_values(
values={
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
"alpha": torch.tensor(alpha),
},
)
)
values = {
"lora_down.weight": src_layer_dict.pop("lora_A.weight"),
"lora_up.weight": src_layer_dict.pop("lora_B.weight"),
}
if alpha is not None:
values["alpha"] = torch.tensor(alpha)
sub_layers.append(LoRALayer.from_state_dict_values(values=values))
assert len(src_layer_dict) == 0
layers[dst_qkv_key] = ConcatenatedLoRALayer(lora_layers=sub_layers, concat_axis=0)

View File

@@ -68,13 +68,12 @@ class LoRALoader(ModelLoader):
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
if config.format == ModelFormat.Diffusers:
# HACK(ryand): We assume alpha=8 for diffusers PEFT format models. These models are typically
# HACK(ryand): We set alpha=None for diffusers PEFT format models. These models are typically
# distributed as a single file without the associated metadata containing the alpha value. We chose
# alpha=8, because this is the default value in the PEFT library:
# https://github.com/huggingface/peft/blob/7868d0372b86a6b9ac5f365b8f0eef2f2f5dedce/src/peft/tuners/lora/config.py#L169
# Other reasonable defaults for alpha could be 1.0 or the rank of the LoRA. If our assumption is wrong,
# the user will need to adjust the weight accordingly to account for the difference.
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=8)
# alpha=None, because this is treated as alpha=rank internally in `LoRALayerBase.scale()`. alpha=rank
# is a popular choice. For example, in the diffusers training scripts:
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1194
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=None)
elif config.format == ModelFormat.LyCORIS:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
else: