Further updates to lora_model_from_flux_diffusers_state_dict() so that it can be re-used for OneTrainer LoRAs.

Ryan Dick
2025-01-22 20:39:15 +00:00
parent 908976ac08
commit 7eee4da896


@@ -33,13 +33,21 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
 def lora_model_from_flux_diffusers_state_dict(
     state_dict: Dict[str, torch.Tensor], alpha: float | None
 ) -> ModelPatchRaw:
-    """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
+    """Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object."""
+    # Group keys by layer.
+    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
+    layers = lora_layers_from_flux_diffusers_grouped_state_dict(grouped_state_dict, alpha)
+    return ModelPatchRaw(layers=layers)
+
+
+def lora_layers_from_flux_diffusers_grouped_state_dict(
+    grouped_state_dict: Dict[str, Dict[str, torch.Tensor]], alpha: float | None
+) -> dict[str, BaseLayerPatch]:
+    """Converts a grouped state dict with Diffusers FLUX LoRA keys to LoRA layers with BFL keys (i.e. the module key
+    format used by Invoke).
 
     This function is based on:
     https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
     """
-    # Group keys by layer.
-    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
-
     # Remove the "transformer." prefix from all keys.
     grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
@@ -62,6 +70,7 @@ def lora_model_from_flux_diffusers_state_dict(
             }
             if alpha is not None:
                 values["alpha"] = torch.tensor(alpha)
+            assert len(src_layer_dict) == 0
             return values
         else:
             # Assume that the LoRA keys are in Kohya format.
@@ -72,7 +81,6 @@ def lora_model_from_flux_diffusers_state_dict(
             src_layer_dict = grouped_state_dict.pop(src_key)
             values = get_lora_layer_values(src_layer_dict)
             layers[dst_key] = any_lora_layer_from_state_dict(values)
-            assert len(src_layer_dict) == 0
 
     def add_qkv_lora_layer_if_present(
         src_keys: list[str],
@@ -96,7 +104,6 @@ def lora_model_from_flux_diffusers_state_dict(
                 assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
                 assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
                 sub_layers.append(any_lora_layer_from_state_dict(values))
-                assert len(src_layer_dict) == 0
             else:
                 if not allow_missing_keys:
                     raise ValueError(f"Missing LoRA layer: '{src_key}'.")
@@ -221,7 +228,7 @@ def lora_model_from_flux_diffusers_state_dict(
     layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
 
-    return ModelPatchRaw(layers=layers_with_prefix)
+    return layers_with_prefix
 
 
 def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
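
The split above separates the Diffusers-specific entry point from the grouped-dict-to-BFL-layer conversion, which is what allows other loaders to share the conversion logic. Below is a minimal sketch of how an OneTrainer FLUX LoRA loader could re-use the new helper; the import paths, the entry-point name, and the key-remapping step are illustrative assumptions and not part of this commit. Only _group_by_layer, lora_layers_from_flux_diffusers_grouped_state_dict, and ModelPatchRaw(layers=...) are taken from the diff itself.

# Sketch only: re-using the shared helper from a hypothetical OneTrainer loader.
# Import paths below are assumptions for illustration.
from typing import Dict

import torch

from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (  # assumed path
    _group_by_layer,
    lora_layers_from_flux_diffusers_grouped_state_dict,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw  # assumed path


def _onetrainer_key_to_diffusers_key(key: str) -> str:
    # Placeholder for whatever per-key renaming an OneTrainer checkpoint needs before its
    # keys match the Diffusers FLUX LoRA layout expected by the shared helper.
    return key


def lora_model_from_flux_onetrainer_state_dict(
    state_dict: Dict[str, torch.Tensor], alpha: float | None
) -> ModelPatchRaw:
    # Remap keys into the Diffusers layout (hypothetical step), group them by layer, then
    # delegate the grouped-dict-to-BFL-layer conversion to the shared helper.
    remapped = {_onetrainer_key_to_diffusers_key(k): v for k, v in state_dict.items()}
    grouped = _group_by_layer(remapped)
    layers = lora_layers_from_flux_diffusers_grouped_state_dict(grouped, alpha)
    return ModelPatchRaw(layers=layers)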