mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
Further updates to lora_model_from_flux_diffusers_state_dict() so that it can be re-used for OneTrainer LoRAs.
This commit is contained in:
@@ -33,13 +33,21 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: Dict[str, torch.Te
|
||||
def lora_model_from_flux_diffusers_state_dict(
|
||||
state_dict: Dict[str, torch.Tensor], alpha: float | None
|
||||
) -> ModelPatchRaw:
|
||||
"""Loads a state dict in the Diffusers FLUX LoRA format into a LoRAModelRaw object.
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
|
||||
layers = lora_layers_from_flux_diffusers_grouped_state_dict(grouped_state_dict, alpha)
|
||||
return ModelPatchRaw(layers=layers)
|
||||
|
||||
|
||||
def lora_layers_from_flux_diffusers_grouped_state_dict(
|
||||
grouped_state_dict: Dict[str, Dict[str, torch.Tensor]], alpha: float | None
|
||||
) -> dict[str, BaseLayerPatch]:
|
||||
"""Converts a grouped state dict with Diffusers FLUX LoRA keys to LoRA layers with BFL keys (i.e. the module key
|
||||
format used by Invoke).
|
||||
|
||||
This function is based on:
|
||||
https://github.com/huggingface/diffusers/blob/55ac421f7bb12fd00ccbef727be4dc2f3f920abb/scripts/convert_flux_to_diffusers.py
|
||||
"""
|
||||
# Group keys by layer.
|
||||
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = _group_by_layer(state_dict)
|
||||
|
||||
# Remove the "transformer." prefix from all keys.
|
||||
grouped_state_dict = {k.replace("transformer.", ""): v for k, v in grouped_state_dict.items()}
|
||||
@@ -62,6 +70,7 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
}
|
||||
if alpha is not None:
|
||||
values["alpha"] = torch.tensor(alpha)
|
||||
assert len(src_layer_dict) == 0
|
||||
return values
|
||||
else:
|
||||
# Assume that the LoRA keys are in Kohya format.
|
||||
@@ -72,7 +81,6 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
src_layer_dict = grouped_state_dict.pop(src_key)
|
||||
values = get_lora_layer_values(src_layer_dict)
|
||||
layers[dst_key] = any_lora_layer_from_state_dict(values)
|
||||
assert len(src_layer_dict) == 0
|
||||
|
||||
def add_qkv_lora_layer_if_present(
|
||||
src_keys: list[str],
|
||||
@@ -96,7 +104,6 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
assert values["lora_down.weight"].shape[1] == src_weight_shape[1]
|
||||
assert values["lora_up.weight"].shape[0] == src_weight_shape[0]
|
||||
sub_layers.append(any_lora_layer_from_state_dict(values))
|
||||
assert len(src_layer_dict) == 0
|
||||
else:
|
||||
if not allow_missing_keys:
|
||||
raise ValueError(f"Missing LoRA layer: '{src_key}'.")
|
||||
@@ -221,7 +228,7 @@ def lora_model_from_flux_diffusers_state_dict(
|
||||
|
||||
layers_with_prefix = {f"{FLUX_LORA_TRANSFORMER_PREFIX}{k}": v for k, v in layers.items()}
|
||||
|
||||
return ModelPatchRaw(layers=layers_with_prefix)
|
||||
return layers_with_prefix
|
||||
|
||||
|
||||
def _group_by_layer(state_dict: Dict[str, torch.Tensor]) -> dict[str, dict[str, torch.Tensor]]:
|
||||
|
||||
Reference in New Issue
Block a user