From 206f261e45bc9d9749318f975fe15376a3c64ecd Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Wed, 22 Jan 2025 20:41:14 +0000
Subject: [PATCH] Add utils for loading FLUX OneTrainer DoRA models.

---
 .../patches/layers/concatenated_lora_layer.py |   3 +-
 .../flux_kohya_lora_conversion_utils.py       |  48 ++++++-
 .../lora_conversions/flux_lora_constants.py   |   3 +-
 .../flux_onetrainer_lora_conversion_utils.py  | 132 +++++++++++++++++
 ...t_flux_onetrainer_lora_conversion_utils.py |  33 +++++
 5 files changed, 210 insertions(+), 9 deletions(-)

diff --git a/invokeai/backend/patches/layers/concatenated_lora_layer.py b/invokeai/backend/patches/layers/concatenated_lora_layer.py
index a699a47433..b2bc63a39a 100644
--- a/invokeai/backend/patches/layers/concatenated_lora_layer.py
+++ b/invokeai/backend/patches/layers/concatenated_lora_layer.py
@@ -2,7 +2,6 @@ from typing import Optional, Sequence
 
 import torch
 
-from invokeai.backend.patches.layers.lora_layer import LoRALayer
 from invokeai.backend.patches.layers.lora_layer_base import LoRALayerBase
 
 
@@ -14,7 +13,7 @@ class ConcatenatedLoRALayer(LoRALayerBase):
     stored as separate tensors. This class enables diffusers LoRA layers to be used in BFL FLUX models.
     """
 
-    def __init__(self, lora_layers: Sequence[LoRALayer], concat_axis: int = 0):
+    def __init__(self, lora_layers: Sequence[LoRALayerBase], concat_axis: int = 0):
         super().__init__(alpha=None, bias=None)
 
         self.lora_layers = lora_layers
diff --git a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
index 1803f0dc7a..5866878b94 100644
--- a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py
@@ -7,6 +7,7 @@ from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
 from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
 from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
     FLUX_LORA_CLIP_PREFIX,
+    FLUX_LORA_T5_PREFIX,
     FLUX_LORA_TRANSFORMER_PREFIX,
 )
 from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
@@ -32,7 +33,7 @@ FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self
 #   lora_te2_encoder_block_0_layer_0_SelfAttention_k.dora_scale
 #   lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_down.weight
 #   lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_up.weight
-FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)\.?.*"
+FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)_?(\w+)?\.?.*"
 
 
 def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool:
@@ -58,27 +59,34 @@ def lora_model_from_flux_kohya_state_dict(state_dict: Dict[str, torch.Tensor]) -
             grouped_state_dict[layer_name] = {}
         grouped_state_dict[layer_name][param_name] = value
 
-    # Split the grouped state dict into transformer and CLIP state dicts.
+    # Split the grouped state dict into transformer, CLIP, and T5 state dicts.
     transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
     clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
     for layer_name, layer_state_dict in grouped_state_dict.items():
         if layer_name.startswith("lora_unet"):
             transformer_grouped_sd[layer_name] = layer_state_dict
         elif layer_name.startswith("lora_te1"):
             clip_grouped_sd[layer_name] = layer_state_dict
+        elif layer_name.startswith("lora_te2"):
+            t5_grouped_sd[layer_name] = layer_state_dict
         else:
             raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
 
     # Convert the state dicts to the InvokeAI format.
     transformer_grouped_sd = _convert_flux_transformer_kohya_state_dict_to_invoke_format(transformer_grouped_sd)
     clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
+    t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
 
     # Create LoRA layers.
     layers: dict[str, BaseLayerPatch] = {}
-    for layer_key, layer_state_dict in transformer_grouped_sd.items():
-        layers[FLUX_LORA_TRANSFORMER_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
-    for layer_key, layer_state_dict in clip_grouped_sd.items():
-        layers[FLUX_LORA_CLIP_PREFIX + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+    for model_prefix, grouped_sd in [
+        (FLUX_LORA_TRANSFORMER_PREFIX, transformer_grouped_sd),
+        (FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
+        (FLUX_LORA_T5_PREFIX, t5_grouped_sd),
+    ]:
+        for layer_key, layer_state_dict in grouped_sd.items():
+            layers[model_prefix + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
 
     # Create and return the LoRAModelRaw.
     return ModelPatchRaw(layers=layers)
@@ -133,3 +141,31 @@ def _convert_flux_transformer_kohya_state_dict_to_invoke_format(state_dict: Dict
         raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
 
     return converted_dict
+
+
+def _convert_flux_t5_kohya_state_dict_to_invoke_format(state_dict: Dict[str, T]) -> Dict[str, T]:
+    """Converts a T5 LoRA state dict from the Kohya FLUX LoRA format to the LoRA weight format used internally by
+    InvokeAI.
+
+    Example key conversions:
+
+    "lora_te2_encoder_block_0_layer_0_SelfAttention_k" -> "encoder.block.0.layer.0.SelfAttention.k"
+    "lora_te2_encoder_block_0_layer_1_DenseReluDense_wi_0" -> "encoder.block.0.layer.1.DenseReluDense.wi_0"
+    """
+
+    def replace_func(match: re.Match[str]) -> str:
+        s = f"encoder.block.{match.group(1)}.layer.{match.group(2)}.{match.group(3)}.{match.group(4)}"
+        if match.group(5):
+            s += f".{match.group(5)}"
+        return s
+
+    converted_dict: dict[str, T] = {}
+    for k, v in state_dict.items():
+        match = re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
+        if match:
+            new_key = re.sub(FLUX_KOHYA_T5_KEY_REGEX, replace_func, k)
+            converted_dict[new_key] = v
+        else:
+            raise ValueError(f"Key '{k}' does not match the expected pattern for FLUX LoRA weights.")
+
+    return converted_dict
diff --git a/invokeai/backend/patches/lora_conversions/flux_lora_constants.py b/invokeai/backend/patches/lora_conversions/flux_lora_constants.py
index 4f854d1442..2857514462 100644
--- a/invokeai/backend/patches/lora_conversions/flux_lora_constants.py
+++ b/invokeai/backend/patches/lora_conversions/flux_lora_constants.py
@@ -1,3 +1,4 @@
-# Prefixes used to distinguish between transformer and CLIP text encoder keys in the FLUX InvokeAI LoRA format.
+# Prefixes used to distinguish between the transformer, CLIP text encoder, and T5 text encoder keys in the FLUX InvokeAI LoRA format.
 FLUX_LORA_TRANSFORMER_PREFIX = "lora_transformer-"
 FLUX_LORA_CLIP_PREFIX = "lora_clip-"
+FLUX_LORA_T5_PREFIX = "lora_t5-"
diff --git a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py
index 661014c357..0413f0ef49 100644
--- a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py
+++ b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py
@@ -1,10 +1,29 @@
 import re
 from typing import Any, Dict
 
+import torch
+
+from invokeai.backend.patches.layers.base_layer_patch import BaseLayerPatch
+from invokeai.backend.patches.layers.utils import any_lora_layer_from_state_dict
+from invokeai.backend.patches.lora_conversions.flux_diffusers_lora_conversion_utils import (
+    lora_layers_from_flux_diffusers_grouped_state_dict,
+)
 from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import (
     FLUX_KOHYA_CLIP_KEY_REGEX,
     FLUX_KOHYA_T5_KEY_REGEX,
+    _convert_flux_clip_kohya_state_dict_to_invoke_format,
+    _convert_flux_t5_kohya_state_dict_to_invoke_format,
 )
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
+    FLUX_LORA_CLIP_PREFIX,
+    FLUX_LORA_T5_PREFIX,
+)
+from invokeai.backend.patches.lora_conversions.kohya_key_utils import (
+    INDEX_PLACEHOLDER,
+    ParsingTree,
+    insert_periods_into_kohya_key,
+)
+from invokeai.backend.patches.model_patch_raw import ModelPatchRaw
 
 # A regex pattern that matches all of the transformer keys in the OneTrainer FLUX LoRA format.
 # The OneTrainer format uses a mix of the Kohya and Diffusers formats:
@@ -35,3 +54,116 @@ def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -
         or re.match(FLUX_KOHYA_T5_KEY_REGEX, k)
         for k in state_dict.keys()
     )
+
+
+def lora_model_from_flux_onetrainer_state_dict(state_dict: Dict[str, torch.Tensor]) -> ModelPatchRaw:  # type: ignore
+    # Group keys by layer.
+    grouped_state_dict: dict[str, dict[str, torch.Tensor]] = {}
+    for key, value in state_dict.items():
+        layer_name, param_name = key.split(".", 1)
+        if layer_name not in grouped_state_dict:
+            grouped_state_dict[layer_name] = {}
+        grouped_state_dict[layer_name][param_name] = value
+
+    # Split the grouped state dict into transformer, CLIP, and T5 state dicts.
+    transformer_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    clip_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    t5_grouped_sd: dict[str, dict[str, torch.Tensor]] = {}
+    for layer_name, layer_state_dict in grouped_state_dict.items():
+        if layer_name.startswith("lora_transformer"):
+            transformer_grouped_sd[layer_name] = layer_state_dict
+        elif layer_name.startswith("lora_te1"):
+            clip_grouped_sd[layer_name] = layer_state_dict
+        elif layer_name.startswith("lora_te2"):
+            t5_grouped_sd[layer_name] = layer_state_dict
+        else:
+            raise ValueError(f"Layer '{layer_name}' does not match the expected pattern for FLUX LoRA weights.")
+
+    # Convert the state dicts to the InvokeAI format.
+    clip_grouped_sd = _convert_flux_clip_kohya_state_dict_to_invoke_format(clip_grouped_sd)
+    t5_grouped_sd = _convert_flux_t5_kohya_state_dict_to_invoke_format(t5_grouped_sd)
+
+    # Create LoRA layers.
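+    # Note: The transformer layers are deliberately excluded from the generic loop below. OneTrainer
+    # stores them with Diffusers-style module names, so they are converted separately by
+    # _convert_flux_transformer_onetrainer_state_dict_to_invoke_format().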
+    layers: dict[str, BaseLayerPatch] = {}
+    for model_prefix, grouped_sd in [
+        (FLUX_LORA_CLIP_PREFIX, clip_grouped_sd),
+        (FLUX_LORA_T5_PREFIX, t5_grouped_sd),
+    ]:
+        for layer_key, layer_state_dict in grouped_sd.items():
+            layers[model_prefix + layer_key] = any_lora_layer_from_state_dict(layer_state_dict)
+
+    # Handle the transformer.
+    transformer_layers = _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(transformer_grouped_sd)
+    layers.update(transformer_layers)
+
+    # Create and return the LoRAModelRaw.
+    return ModelPatchRaw(layers=layers)
+
+
+# This parsing tree was generated by calling `generate_kohya_parsing_tree_from_keys()` on the keys in
+# flux_lora_diffusers_format.py.
+flux_transformer_kohya_parsing_tree: ParsingTree = {
+    "transformer": {
+        "single_transformer_blocks": {
+            INDEX_PLACEHOLDER: {
+                "attn": {"to_k": {}, "to_q": {}, "to_v": {}},
+                "norm": {"linear": {}},
+                "proj_mlp": {},
+                "proj_out": {},
+            }
+        },
+        "transformer_blocks": {
+            INDEX_PLACEHOLDER: {
+                "attn": {
+                    "add_k_proj": {},
+                    "add_q_proj": {},
+                    "add_v_proj": {},
+                    "to_add_out": {},
+                    "to_k": {},
+                    "to_out": {INDEX_PLACEHOLDER: {}},
+                    "to_q": {},
+                    "to_v": {},
+                },
+                "ff": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
+                "ff_context": {"net": {INDEX_PLACEHOLDER: {"proj": {}}}},
+                "norm1": {"linear": {}},
+                "norm1_context": {"linear": {}},
+            }
+        },
+    }
+}
+
+
+def _convert_flux_transformer_onetrainer_state_dict_to_invoke_format(
+    state_dict: Dict[str, Dict[str, torch.Tensor]],
+) -> dict[str, BaseLayerPatch]:
+    """Converts a FLUX transformer LoRA state dict from the OneTrainer FLUX LoRA format to the LoRA weight format used
+    internally by InvokeAI.
+    """
+
+    # Step 1: Convert the Kohya-style underscore-delimited keys to dot-delimited module paths.
+    # Example:
+    # "lora_transformer_single_transformer_blocks_0_attn_to_k" -> "transformer.single_transformer_blocks.0.attn.to_k"
+    lora_prefix = "lora_"
+    lora_prefix_length = len(lora_prefix)
+    kohya_state_dict: dict[str, dict[str, torch.Tensor]] = {}
+    for key in state_dict.keys():
+        # Remove the "lora_" prefix.
+        assert key.startswith(lora_prefix)
+        new_key = key[lora_prefix_length:]
+
+        # Add periods to the Kohya-style module keys.
+        new_key = insert_periods_into_kohya_key(new_key, flux_transformer_kohya_parsing_tree)
+
+        # Store the layer state dict under the new key.
+        kohya_state_dict[new_key] = state_dict[key]
+
+    # Step 2: Convert the Diffusers module names to the BFL module names.
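+    # lora_layers_from_flux_diffusers_grouped_state_dict() maps the Diffusers module names to BFL
+    # module names (e.g. "transformer.transformer_blocks.{i}.*" -> "double_blocks.{i}.*") and fuses
+    # the separate Q/K/V (and proj_mlp) projections into the concatenated layers that the BFL
+    # checkpoint layout expects (see ConcatenatedLoRALayer).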
+    return lora_layers_from_flux_diffusers_grouped_state_dict(kohya_state_dict, alpha=None)
diff --git a/tests/backend/patches/lora_conversions/test_flux_onetrainer_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_onetrainer_lora_conversion_utils.py
index 9e49ac20ab..cf8a27d5ad 100644
--- a/tests/backend/patches/lora_conversions/test_flux_onetrainer_lora_conversion_utils.py
+++ b/tests/backend/patches/lora_conversions/test_flux_onetrainer_lora_conversion_utils.py
@@ -1,7 +1,13 @@
 import pytest
+from invokeai.backend.patches.lora_conversions.flux_lora_constants import (
+    FLUX_LORA_CLIP_PREFIX,
+    FLUX_LORA_T5_PREFIX,
+    FLUX_LORA_TRANSFORMER_PREFIX,
+)
 from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
     is_state_dict_likely_in_flux_onetrainer_format,
+    lora_model_from_flux_onetrainer_state_dict,
 )
 from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import (
     state_dict_keys as flux_onetrainer_state_dict_keys,
 )
@@ -42,3 +48,30 @@ def test_is_state_dict_likely_in_flux_onetrainer_format_false(sd_keys: dict[str,
     """
     state_dict = keys_to_mock_state_dict(sd_keys)
     assert not is_state_dict_likely_in_flux_onetrainer_format(state_dict)
+
+
+def test_lora_model_from_flux_onetrainer_state_dict():
+    state_dict = keys_to_mock_state_dict(flux_onetrainer_state_dict_keys)
+
+    lora_model = lora_model_from_flux_onetrainer_state_dict(state_dict)
+
+    # Check that the model has the correct number of LoRA layers.
+    expected_lora_layers: set[str] = set()
+    for k in flux_onetrainer_state_dict_keys:
+        k = k.replace(".lora_up.weight", "")
+        k = k.replace(".lora_down.weight", "")
+        k = k.replace(".alpha", "")
+        k = k.replace(".dora_scale", "")
+        expected_lora_layers.add(k)
+    # Drop the K/V/proj_mlp weights because these are all concatenated into a single layer in the BFL format (we keep
+    # the Q weights so that we count these layers once).
+    concatenated_weights = ["to_k", "to_v", "proj_mlp", "add_k_proj", "add_v_proj"]
+    expected_lora_layers = {k for k in expected_lora_layers if not any(w in k for w in concatenated_weights)}
+
+    assert len(lora_model.layers) == len(expected_lora_layers)
+
+    # Check that all of the layers have the expected prefix.
+    assert all(
+        k.startswith((FLUX_LORA_TRANSFORMER_PREFIX, FLUX_LORA_CLIP_PREFIX, FLUX_LORA_T5_PREFIX))
+        for k in lora_model.layers.keys()
+    )
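
Usage sketch: a minimal example of how the two new entry points are expected to combine, assuming a OneTrainer FLUX DoRA checkpoint on disk. The checkpoint path is hypothetical; load_file() is from the safetensors package.

    import torch
    from safetensors.torch import load_file

    from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import (
        is_state_dict_likely_in_flux_onetrainer_format,
        lora_model_from_flux_onetrainer_state_dict,
    )

    # Load a OneTrainer FLUX DoRA checkpoint (path is hypothetical).
    state_dict: dict[str, torch.Tensor] = load_file("flux_onetrainer_dora.safetensors")

    # Cheap key-pattern check before committing to the full conversion.
    if is_state_dict_likely_in_flux_onetrainer_format(state_dict):
        model_patch = lora_model_from_flux_onetrainer_state_dict(state_dict)
        print(f"Converted {len(model_patch.layers)} layers to the InvokeAI LoRA format.")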