mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Assume LoRA alpha=8 for FLUX diffusers PEFT LoRAs.
This commit is contained in:
@@ -30,8 +30,7 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
|
||||
return all_keys_in_peft_format and all_expected_keys_present
|
||||
|
||||
|
||||
# TODO(ryand): What alpha should we use? 1.0? Rank of the LoRA?
|
||||
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float = 1.0) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
|
||||
def lora_model_from_flux_diffusers_state_dict(state_dict: Dict[str, torch.Tensor], alpha: float) -> LoRAModelRaw: # pyright: ignore[reportRedeclaration] (state_dict is intentionally re-declared)
|
||||
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
|
||||
|
||||
This function is based on:
|
||||
|
||||
@@ -68,7 +68,13 @@ class LoRALoader(ModelLoader):
|
||||
model = lora_model_from_sd_state_dict(state_dict=state_dict)
|
||||
elif self._model_base == BaseModelType.Flux:
|
||||
if config.format == ModelFormat.Diffusers:
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict)
|
||||
# HACK(ryand): We assume alpha=8 for diffusers PEFT format models. These models are typically
|
||||
# distributed as a single file without the associated metadata containing the alpha value. We chose
|
||||
# alpha=8, because this is the default value in the PEFT library:
|
||||
# https://github.com/huggingface/peft/blob/7868d0372b86a6b9ac5f365b8f0eef2f2f5dedce/src/peft/tuners/lora/config.py#L169
|
||||
# Other reasonable defaults for alpha could be 1.0 or the rank of the LoRA. If our assumption is wrong,
|
||||
# the user will need to adjust the weight accordingly to account for the difference.
|
||||
model = lora_model_from_flux_diffusers_state_dict(state_dict=state_dict, alpha=8)
|
||||
elif config.format == ModelFormat.LyCORIS:
|
||||
model = lora_model_from_flux_kohya_state_dict(state_dict=state_dict)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user