From 5bd6428fddec2b73c8de0ebc540961435e36c868 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 21 Jan 2025 16:39:59 +0000 Subject: [PATCH] Add is_state_dict_likely_in_flux_onetrainer_format() util function. --- .../flux_kohya_lora_conversion_utils.py | 12 ++++- .../flux_onetrainer_lora_conversion_utils.py | 37 ++++++++++++++++ ...t_flux_onetrainer_lora_conversion_utils.py | 44 +++++++++++++++++++ 3 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py create mode 100644 tests/backend/patches/lora_conversions/test_flux_onetrainer_lora_conversion_utils.py diff --git a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py index 6ff0d2fa3c..1803f0dc7a 100644 --- a/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_kohya_lora_conversion_utils.py @@ -26,6 +26,14 @@ FLUX_KOHYA_TRANSFORMER_KEY_REGEX = ( # lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_up.weight FLUX_KOHYA_CLIP_KEY_REGEX = r"lora_te1_text_model_encoder_layers_(\d+)_(mlp|self_attn)_(\w+)\.?.*" +# A regex pattern that matches all of the T5 keys in the Kohya FLUX LoRA format. +# Example keys: +# lora_te2_encoder_block_0_layer_0_SelfAttention_k.alpha +# lora_te2_encoder_block_0_layer_0_SelfAttention_k.dora_scale +# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_down.weight +# lora_te2_encoder_block_0_layer_0_SelfAttention_k.lora_up.weight +FLUX_KOHYA_T5_KEY_REGEX = r"lora_te2_encoder_block_(\d+)_layer_(\d+)_(DenseReluDense|SelfAttention)_(\w+)\.?.*" + def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> bool: """Checks if the provided state dict is likely in the Kohya FLUX LoRA format. 
@@ -34,7 +42,9 @@ def is_state_dict_likely_in_flux_kohya_format(state_dict: Dict[str, Any]) -> boo perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) """ return all( - re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) + re.match(FLUX_KOHYA_TRANSFORMER_KEY_REGEX, k) + or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) + or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) for k in state_dict.keys() ) diff --git a/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py new file mode 100644 index 0000000000..661014c357 --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/flux_onetrainer_lora_conversion_utils.py @@ -0,0 +1,37 @@ +import re +from typing import Any, Dict + +from invokeai.backend.patches.lora_conversions.flux_kohya_lora_conversion_utils import ( + FLUX_KOHYA_CLIP_KEY_REGEX, + FLUX_KOHYA_T5_KEY_REGEX, +) + +# A regex pattern that matches all of the transformer keys in the OneTrainer FLUX LoRA format. +# The OneTrainer format uses a mix of the Kohya and Diffusers formats: +# - The base model keys are in Diffusers format. +# - Periods are replaced with underscores, to match Kohya. +# - The LoRA key suffixes (e.g. .alpha, .lora_down.weight, .lora_up.weight) match Kohya. 
+# Example keys: +# - "lora_transformer_single_transformer_blocks_0_attn_to_k.alpha" +# - "lora_transformer_single_transformer_blocks_0_attn_to_k.dora_scale" +# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_down.weight" +# - "lora_transformer_single_transformer_blocks_0_attn_to_k.lora_up.weight" +FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX = ( + r"lora_transformer_(single_transformer_blocks|transformer_blocks)_(\d+)_(\w+)\.(.*)" +) + + +def is_state_dict_likely_in_flux_onetrainer_format(state_dict: Dict[str, Any]) -> bool: + """Checks if the provided state dict is likely in the OneTrainer FLUX LoRA format. + + This is intended to be a high-precision detector, but it is not guaranteed to have perfect precision. (A + perfect-precision detector would require checking all keys against a whitelist and verifying tensor shapes.) + + Note that OneTrainer matches the Kohya format for the CLIP and T5 models. + """ + return all( + re.match(FLUX_ONETRAINER_TRANSFORMER_KEY_REGEX, k) + or re.match(FLUX_KOHYA_CLIP_KEY_REGEX, k) + or re.match(FLUX_KOHYA_T5_KEY_REGEX, k) + for k in state_dict.keys() + ) diff --git a/tests/backend/patches/lora_conversions/test_flux_onetrainer_lora_conversion_utils.py b/tests/backend/patches/lora_conversions/test_flux_onetrainer_lora_conversion_utils.py new file mode 100644 index 0000000000..9e49ac20ab --- /dev/null +++ b/tests/backend/patches/lora_conversions/test_flux_onetrainer_lora_conversion_utils.py @@ -0,0 +1,44 @@ +import pytest + +from invokeai.backend.patches.lora_conversions.flux_onetrainer_lora_conversion_utils import ( + is_state_dict_likely_in_flux_onetrainer_format, +) +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_dora_onetrainer_format import ( + state_dict_keys as flux_onetrainer_state_dict_keys, +) +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_diffusers_format import ( + state_dict_keys as flux_diffusers_state_dict_keys, +) +from 
tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_format import ( + state_dict_keys as flux_kohya_state_dict_keys, +) +from tests.backend.patches.lora_conversions.lora_state_dicts.flux_lora_kohya_with_te1_format import ( + state_dict_keys as flux_kohya_te1_state_dict_keys, +) +from tests.backend.patches.lora_conversions.lora_state_dicts.utils import keys_to_mock_state_dict + + +def test_is_state_dict_likely_in_flux_onetrainer_format_true(): + """Test that is_state_dict_likely_in_flux_onetrainer_format() can identify a state dict in the OneTrainer + FLUX LoRA format. + """ + # Construct a state dict that is in the OneTrainer FLUX LoRA format. + state_dict = keys_to_mock_state_dict(flux_onetrainer_state_dict_keys) + + assert is_state_dict_likely_in_flux_onetrainer_format(state_dict) + + +@pytest.mark.parametrize( + "sd_keys", + [ + flux_kohya_state_dict_keys, + flux_kohya_te1_state_dict_keys, + flux_diffusers_state_dict_keys, + ], +) +def test_is_state_dict_likely_in_flux_onetrainer_format_false(sd_keys: dict[str, list[int]]): + """Test that is_state_dict_likely_in_flux_onetrainer_format() returns False for a state dict that is in another + FLUX LoRA format (Kohya or Diffusers). + """ + state_dict = keys_to_mock_state_dict(sd_keys) + assert not is_state_dict_likely_in_flux_onetrainer_format(state_dict)