Add utils for loading FLUX OneTrainer DoRA models.

This commit is contained in:
Ryan Dick
2025-01-22 20:41:14 +00:00
parent 7eee4da896
commit 206f261e45
5 changed files with 203 additions and 8 deletions

View File

@@ -2,7 +2,6 @@ from typing import Optional, Sequence
import torch
from invokeai.backend.patches.layers.lora_layer import LoRALayer
from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
@@ -14,7 +13,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
"""
def __init__(self, lora_layers: Sequence[LoRALayer], concat_axis: int = 0):
def __init__(self, lora_layers: Sequence[LoRALayerBase], concat_axis: int = 0):
super().__init__(alpha=None, bias=None)
self.lora_layers = lora_layers

View File

@@ -7,6 +7,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_T5_PREFIX,
FLUX_LORA_TRANSFORMER_PREFIX,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -32,7 +33,7 @@ FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.dora_scale
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_down.weight
# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_up.weight
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)\.?.*"
FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
@@ -58,27 +59,34 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Split the grouped state dict into transformer and CLIP state dicts.
# Split the grouped state dict into transformer, CLIP, and T5 state dicts.
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_unet"):
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te2"):
t5_grouped_sd[layer_name] = layer_state_dict
else:
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
# Convert the state dicts to the InvokeAI format.
transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
# Create LoRA layers.
layers: dict[str, BaseLayerPatch] = {}
for layer_key, layer_state_dict in transformer_grouped_sd.items():
layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for layer_key, layer_state_dict in clip_grouped_sd.items():
layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
for model_prefix, grouped_sd in [
(FLUX_LORA_TRANSFORMER_PREFIX, transformer_grouped_sd),
(FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
(FLUX_LORA_T5_PREFIX, t5_grouped_sd),
]:
for layer_key, layer_state_dict in grouped_sd.items():
layers[model_prefix + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Create and return the LoRAModelRaw.
return ModelPatchRaw(layers=layers)
@@ -133,3 +141,31 @@ def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict
def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
"""Converts a T5 LoRA state dict from the Kohya FLUX LoRA format to LoRA weight format used internally by
InvokeAI.
Example key conversions:
"lora_te2_encoder_block_0_layer_0_SelfAttention_k" -> "encoder.block.0.layer.0.SelfAttention.k"
"lora_te2_encoder_block_0_layer_1_DenseReluDense_wi_0" -> "encoder.block.0.layer.1.DenseReluDense.wi.0"
"""
def replace_func(match: re.Match[str]) -> str:
s = f"{match.group(1)}.{match.group(2)}.{match.group(3)}.{match.group(4)}"
if match.group(5):
s += f".{match.group(5)}"
return "encoder.block." + s
converted_dict: dict[str, T] = {}
for k, v in state_dict.items():
match = re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
if match:
new_key = re.sub(FLUX_KOHYA_T5_KEY_REGEX, replace_func, k)
converted_dict[new_key] = v
else:
raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
return converted_dict

View File

@@ -1,3 +1,4 @@
# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
FLUX_LORA_CLIP_PREFIX = "lora_clip-"
FLUX_LORA_T5_PREFIX = "lora_t5-"

View File

@@ -1,10 +1,29 @@
import re
from typing import Any, Dict
import torch
from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
lora_layers_from_flux_diffusers_grouped_state_dict,
)
from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
FLUX_KOHYA_CLIP_KEY_REGEX,
FLUX_KOHYA_T5_KEY_REGEX,
_convert_flux_clip_kohya_state_dict_to_invoke_format,
_convert_flux_t5_kohya_state_dict_to_invoke_format,
)
from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
FLUX_LORA_CLIP_PREFIX,
FLUX_LORA_T5_PREFIX,
)
from invokeai.backend.patches.lora_conversions.kohya_key_utils import (
INDEX_PLACEHOLDER,
ParsingTree,
insert_periods_into_kohya_key,
)
from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
# A regex pattern that matches all of the transformer keys in the OneTrainer FLUX LoRA format.
# The OneTrainer format uses a mix of the Kohya and Diffusers formats:
@@ -35,3 +54,110 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -
or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
for k in state_dict.keys()
)
def lora_model_from_flux_onetrainer_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw: # type: ignore
# Group keys by layer.
grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
for key, value in state_dict.items():
layer_name, param_name = key.split(".", 1)
if layer_name not in grouped_state_dict:
grouped_state_dict[layer_name] = {}
grouped_state_dict[layer_name][param_name] = value
# Split the grouped state dict into transformer, CLIP, and T5 state dicts.
transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
for layer_name, layer_state_dict in grouped_state_dict.items():
if layer_name.startswith("lora_transformer"):
transformer_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te1"):
clip_grouped_sd[layer_name] = layer_state_dict
elif layer_name.startswith("lora_te2"):
t5_grouped_sd[layer_name] = layer_state_dict
else:
raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
# Convert the state dicts to the InvokeAI format.
clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
# Create LoRA layers.
layers: dict[str, BaseLayerPatch] = {}
for model_prefix, grouped_sd in [
# (FLUX_LORA_TRANSFORMER_PREFIX, transformer_grouped_sd),
(FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
(FLUX_LORA_T5_PREFIX, t5_grouped_sd),
]:
for layer_key, layer_state_dict in grouped_sd.items():
layers[model_prefix + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
# Handle the transformer.
transformer_layers = _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(transformer_grouped_sd)
layers.update(transformer_layers)
# Create and return the LoRAModelRaw.
return ModelPatchRaw(layers=layers)
# This parsing tree was generated by calling `generate_kohya_parsing_tree_from_keys()` on the keys in
# flux_lora_diffusers_format.py.
flux_transformer_kohya_parsing_tree: ParsingTree = {
"transformer": {
"single_transformer_blocks": {
INDEX_PLACEHOLDER: {
"attn": {"to_k": {}, "to_q": {}, "to_v": {}},
"norm": {"linear": {}},
"proj_mlp": {},
"proj_out": {},
}
},
"transformer_blocks": {
INDEX_PLACEHOLDER: {
"attn": {
"add_k_proj": {},
"add_q_proj": {},
"add_v_proj": {},
"to_add_out": {},
"to_k": {},
"to_out": {INDEX_PLACEHOLDER: {}},
"to_q": {},
"to_v": {},
},
"ff": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
"ff_context": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
"norm1": {"linear": {}},
"norm1_context": {"linear": {}},
}
},
}
}
def _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(
state_dict: Dict[str, Dict[str, torch.Tensor]],
) -> dict[str, BaseLayerPatch]:
"""Converts a FLUX transformer LoRA state dict from the OneTrainer FLUX LoRA format to the LoRA weight format used
internally by InvokeAI.
"""
# Step 1: Convert the Kohya-style keys with underscores to classic keys with periods.
# Example:
# "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight" -> "transformer.single_transformer_blocks.0.attn.to_k.lora_down.weight"
lora_prefix = "lora_"
lora_prefix_length = len(lora_prefix)
kohya_state_dict: dict[str, Dict[str, torch.Tensor]] = {}
for key in state_dict.keys():
# Remove the "lora_" prefix.
assert key.startswith(lora_prefix)
new_key = key[lora_prefix_length:]
# Add periods to the Kohya-style module keys.
new_key = insert_periods_into_kohya_key(new_key, flux_transformer_kohya_parsing_tree)
# Replace the old key with the new key.
kohya_state_dict[new_key] = state_dict[key]
# Step 2: Convert diffusers module names to the BFL module names.
return lora_layers_from_flux_diffusers_grouped_state_dict(kohya_state_dict, alpha=None)