Assume LoRA alpha=8 for FLUX diffusers PEFT LoRAs.

This commit is contained in:
Ryan Dick
2024-09-12 14:01:41 +00:00
committed by Kent Keirsey
parent 10c3c61cb2
commit 81fbaf2b8b
2 changed files with 8 additions and 3 deletions

View File

@@ -30,8 +30,7 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
return all_keys_in_peft_format and all_expected_keys_present
# TODO(ryand): What alpha should we use? 1.0? Rank of the LoRA?
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float = 1.0) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
This function is based on:

View File

@@ -68,7 +68,13 @@ class LoRALoader(ModelLoader):
model = lora_model_from_sd_state_dict(state_dict=state_dict)
elif self._model_base == BaseModelType.Flux:
if config.format == ModelFormat.Diffusers:
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict)
# HACK(ryand): We assume alpha=8 for diffusers PEFT format models. These models are typically
# distributed as a single file without the associated metadata containing the alpha value. We chose
# alpha=8 because it is the default value in the PEFT library:
# https://github.com/huggingface/peft/blob/7868d0372b86a6b9ac5f365b8f0eef2f2f5dedce/src/peft/tuners/lora/config.py#L169
# Other reasonable defaults for alpha could be 1.0 or the rank of the LoRA. If our assumption is wrong,
# the user will need to adjust the LoRA weight to compensate for the difference.
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=8)
elif config.format == ModelFormat.LyCORIS:
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
else: